package nn

import (
	"fmt"
	"math"
	"sort"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

// useFastExp selects the Schraudolph-style approximation in expf. Set it to
// false to fall back to math.Exp for exact (but slower) softmax weights.
var useFastExp = true
// float16BitsToFloat32 decodes an IEEE 754 binary16 bit pattern into a
// float32, handling subnormals, infinities, and NaN.
func float16BitsToFloat32(bits uint16) float32 {
	sign := uint32(bits&0x8000) << 16
	exp := int32((bits & 0x7C00) >> 10)
	mant := uint32(bits & 0x03FF)
	if exp == 0 {
		if mant == 0 {
			return math.Float32frombits(sign) // signed zero
		}
		// Subnormal: normalize the mantissa, adjusting the exponent.
		for mant&0x0400 == 0 {
			mant <<= 1
			exp--
		}
		exp++
		mant &= 0x03FF
	} else if exp == 0x1F {
		if mant == 0 {
			return math.Float32frombits(sign | 0x7F800000) // +/-Inf
		}
		return math.Float32frombits(sign | 0x7FC00000) // NaN
	}
	// Rebias the exponent from binary16 (15) to float32 (127).
	exp += 127 - 15
	return math.Float32frombits(sign | (uint32(exp) << 23) | (mant << 13))
}
// bfloat16BitsToFloat32 decodes a bfloat16 bit pattern, which is simply the
// upper 16 bits of the corresponding float32.
func bfloat16BitsToFloat32(bits uint16) float32 {
	return math.Float32frombits(uint32(bits) << 16)
}
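
// halfDecodeExamples is an illustrative self-check (an addition for
// exposition, not part of the package API) exercising the decoders above
// with well-known bit patterns: 0x3C00 is 1.0 in binary16, and 0x3F80 is
// 1.0 in bfloat16 (the high half of float32's 0x3F800000).
func halfDecodeExamples() [4]float32 {
	return [4]float32{
		float16BitsToFloat32(0x3C00),  // 1.0
		float16BitsToFloat32(0xC000),  // -2.0
		bfloat16BitsToFloat32(0x3F80), // 1.0
		bfloat16BitsToFloat32(0x4049), // 3.140625 (pi truncated to bfloat16)
	}
}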
// expf computes exp(x) for softmax weights, using a fast approximation when
// useFastExp is set.
func expf(x float32) float32 {
	if useFastExp {
		// Clamp to a range that keeps the bit trick stable. For softmax
		// weights, very negative arguments underflow to ~0 anyway.
		if x < -20 {
			x = -20
		} else if x > 10 {
			x = 10
		}
		// Schraudolph-style fast exp: build the float32 bit pattern
		// directly from a linear function of x. A good tradeoff for
		// softmax weights; much faster than math.Exp.
		const a = 12102203.0   // (1 << 23) / ln(2)
		const b = 1065353216.0 // 127 << 23, the float32 exponent bias in place
		return math.Float32frombits(uint32(float32(a)*x + float32(b)))
	}
	return float32(math.Exp(float64(x)))
}
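
// expfError is an illustrative helper (not used by the kernels below) for
// measuring the fast path against math.Exp. With b set exactly to 127<<23,
// the approximation is exact whenever x is an integer multiple of ln(2) and
// overestimates in between by at most roughly 6%; softmax renormalization
// absorbs most of that error.
func expfError(x float32) float64 {
	exact := math.Exp(float64(x))
	return math.Abs(float64(expf(x))-exact) / exact
}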
// viewData holds a KV view's contents after extraction to host float32
// slices, along with its token range within the sequence.
type viewData struct {
	kData  []float32
	vData  []float32
	start  int
	length int
}
// CausalAttentionCached computes causal attention using cached K/V.
//
//	Q:      [newTokens, numHeads*headDim]     queries for the new tokens only
//	K:      [totalSeqLen, numKVHeads*headDim] full K history including current
//	V:      [totalSeqLen, numKVHeads*headDim] full V history including current
//	Output: [newTokens, numHeads*headDim]
//
// startPos is the position of the first new token in the sequence.
func CausalAttentionCached(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim, startPos int) error {
	newTokens := q.Shape()[0]
	totalSeqLen := k.Shape()[0]
	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
		return nil
	}
	// Split heads into contiguous chunks, one goroutine per chunk.
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
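
// parallelHeads is an illustrative restatement (a hypothetical helper, not
// used above) of the worker-split pattern shared by the three entry points
// in this file: heads are divided into contiguous chunks, one goroutine per
// chunk. The entry points above inline this logic.
func parallelHeads(numHeads, workers int, run func(hStart, hEnd int)) {
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			run(s, e)
		}(start, end)
	}
	wg.Wait()
}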
// CausalAttentionPackedBlocks computes causal attention over packed KV views.
// The packed layout is head-major: [kvHead][tokenWithinBlock][headDim] as a
// flat slice. This avoids kvDim-stride traversal and is a fast CPU path.
func CausalAttentionPackedBlocks(
	q *cpu.Tensor,
	views []kvcache.PackedView,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim, startPos int,
) error {
	newTokens := q.Shape()[0]
	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	// Sort to guarantee increasing start positions.
	sort.Slice(views, func(i, j int) bool {
		return views[i].Start < views[j].Start
	})

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
		return nil
	}
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
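
// packedIndex spells out the head-major packed layout described above: the
// channel c of token t for kv head h lives at h*blockSize*headDim +
// t*headDim + c within the flat K (or V) slice. Hypothetical helper for
// exposition only; runCausalPackedHeads inlines this arithmetic.
func packedIndex(kvHead, blockSize, headDim, t, c int) int {
	return kvHead*blockSize*headDim + t*headDim + c
}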
func runCausalCachedHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		kvHeadOffset := kvHead * headDim
		for qi := 0; qi < newTokens; qi++ {
			// Causal mask: token qi may attend to keys [0, startPos+qi].
			maxKeyPos := startPos + qi + 1
			if maxKeyPos > totalSeqLen {
				maxKeyPos = totalSeqLen
			}
			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			outPtr := &outData[outBase]
			clear(outVec)
			// Online softmax: m is the running max, l the running sum of
			// exp(s-m); outVec accumulates the unnormalized weighted V.
			m := float32(-math.MaxFloat32)
			l := float32(0)
			for ti := 0; ti < maxKeyPos; ti++ {
				kBase := ti*strideKV + kvHeadOffset
				kPtr := &kData[kBase]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)
				vBase := ti*strideKV + kvHeadOffset
				vPtr := &vData[vBase]
				if s > m {
					// New max: rescale past contributions by exp(oldMax-newMax).
					alpha := expf(m - s)
					if l != 0 {
						for i := 0; i < headDim; i++ {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
					continue
				}
				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
			}
			// Normalize by the softmax denominator.
			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}
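
// onlineSoftmaxReference is an illustrative scalar restatement (not used by
// the kernels) of the streaming softmax recurrence above: a single pass
// keeps a running max m and running sum l, rescaling past weights whenever
// a new max appears, so no intermediate exp can overflow.
func onlineSoftmaxReference(scores []float32) []float32 {
	m := float32(math.Inf(-1))
	l := float32(0)
	w := make([]float32, len(scores))
	for i, s := range scores {
		if s > m {
			alpha := float32(math.Exp(float64(m - s)))
			for j := 0; j < i; j++ {
				w[j] *= alpha
			}
			l = l*alpha + 1
			m = s
			w[i] = 1
			continue
		}
		w[i] = float32(math.Exp(float64(s - m)))
		l += w[i]
	}
	if l != 0 {
		inv := 1 / l
		for i := range w {
			w[i] *= inv
		}
	}
	return w
}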
func runCausalPackedHeads(qData, outData []float32, views []kvcache.PackedView, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1
			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			outPtr := &outData[outBase]
			clear(outVec)
			m := float32(-math.MaxFloat32)
			l := float32(0)
			for _, pv := range views {
				// Skip empty views and views entirely beyond the causal mask.
				if pv.Length == 0 || pv.Start >= maxKeyPos {
					continue
				}
				// Skip views whose geometry does not match this call.
				if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
					continue
				}
				blkStride := pv.BlockSize * headDim
				headBase := kvHead * blkStride
				if headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
					continue
				}
				// Truncate the view at the causal boundary.
				viewLimit := pv.Length
				if pv.Start+viewLimit > maxKeyPos {
					viewLimit = maxKeyPos - pv.Start
				}
				kHead := pv.K[headBase : headBase+blkStride]
				vHead := pv.V[headBase : headBase+blkStride]
				for t := 0; t < viewLimit; t++ {
					kPtr := &kHead[t*headDim]
					s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)
					vPtr := &vHead[t*headDim]
					if s > m {
						alpha := expf(m - s)
						if l != 0 {
							for i := 0; i < headDim; i++ {
								outVec[i] *= alpha
							}
							l *= alpha
						}
						m = s
						l += 1
						cpu.AxpyPtr(1, vPtr, outPtr, headDim)
						continue
					}
					w := expf(s - m)
					l += w
					cpu.AxpyPtr(w, vPtr, outPtr, headDim)
				}
			}
			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}
// CausalAttentionBlocks computes attention directly over KV block views
// without materializing a contiguous history tensor. startPos is the
// absolute position of the first new token (the cache length before the
// append).
func CausalAttentionBlocks(
	q *cpu.Tensor,
	views []kvcache.View,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim, startPos int,
) error {
	newTokens := q.Shape()[0]
	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	// Pre-extract data from all views (handles both CPU and GPU tensors).
	viewsData := make([]viewData, len(views))
	for i, v := range views {
		if v.Length == 0 {
			continue
		}
		kData, err := tensorToFloat32(v.K)
		if err != nil {
			return fmt.Errorf("failed to get K data from view: %w", err)
		}
		vData, err := tensorToFloat32(v.V)
		if err != nil {
			return fmt.Errorf("failed to get V data from view: %w", err)
		}
		viewsData[i] = viewData{
			kData:  kData,
			vData:  vData,
			start:  v.Start,
			length: v.Length,
		}
	}
	// Sort to guarantee increasing start positions.
	sort.Slice(viewsData, func(i, j int) bool {
		return viewsData[i].start < viewsData[j].start
	})

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
		return nil
	}
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
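
// kvHeadFor spells out the grouped-query attention (GQA) mapping used by all
// three kernels: query head h reads kv head h/groupSize, where groupSize =
// numHeads/numKVHeads. Hypothetical helper for exposition only; the kernels
// compute groupSize once and inline the division.
func kvHeadFor(h, numHeads, numKVHeads int) int {
	return h / (numHeads / numKVHeads)
}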
func runCausalBlockHeads(qData, outData []float32, viewsData []viewData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		kvHeadOffset := kvHead * headDim
		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1
			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			clear(outVec)
			m := float32(-math.MaxFloat32)
			l := float32(0)
			for _, vd := range viewsData {
				if vd.start >= maxKeyPos || vd.length == 0 {
					continue
				}
				viewLimit := vd.length
				if vd.start+viewLimit > maxKeyPos {
					viewLimit = maxKeyPos - vd.start
				}
				for local := 0; local < viewLimit; local++ {
					kvIdx := local*strideKV + kvHeadOffset
					kVec := vd.kData[kvIdx : kvIdx+headDim]
					s := cpu.DotFloat32(qVec, kVec) * float32(scale)
					vVec := vd.vData[kvIdx : kvIdx+headDim]
					if s > m {
						alpha := expf(m - s)
						if l != 0 {
							for i := 0; i < headDim; i++ {
								outVec[i] *= alpha
							}
							l *= alpha
						}
						m = s
						l += 1
						cpu.Axpy(1, vVec, outVec)
						continue
					}
					w := expf(s - m)
					l += w
					cpu.Axpy(w, vVec, outVec)
				}
			}
			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}
// tensorToFloat32 extracts float32 data from a tensor, handling both CPU and
// CUDA tensors and converting float16/bfloat16 payloads on the fly.
func tensorToFloat32(t tensor.Tensor) ([]float32, error) {
	switch tt := t.(type) {
	case *cpu.Tensor:
		switch tt.DType() {
		case tensor.Float32:
			return tt.DataFloat32(), nil
		case tensor.Float16:
			in := tt.DataUint16()
			out := make([]float32, len(in))
			for i := range in {
				out[i] = float16BitsToFloat32(in[i])
			}
			return out, nil
		case tensor.BFloat16:
			in := tt.DataUint16()
			out := make([]float32, len(in))
			for i := range in {
				out[i] = bfloat16BitsToFloat32(in[i])
			}
			return out, nil
		default:
			return nil, fmt.Errorf("unsupported CPU tensor dtype: %v", tt.DType())
		}
	case *cuda.Tensor:
		// Device tensors are copied to host memory first.
		data := make([]float32, t.Shape().NumElements())
		if err := tt.CopyToHost(data); err != nil {
			return nil, err
		}
		return data, nil
	default:
		return nil, fmt.Errorf("unsupported tensor type: %T", t)
	}
}

// cpuDevice reports the device type targeted by this package's kernels.
func cpuDevice() tensor.DeviceType { return tensor.CPU }