package nn

import (
	"fmt"
	"math"
	"sort"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

var useFastExp = true

func float16BitsToFloat32(bits uint16) float32 {
	sign := uint32(bits&0x8000) << 16
	exp := int32((bits & 0x7C00) >> 10)
	mant := uint32(bits & 0x03FF)
	if exp == 0 {
		if mant == 0 {
			return math.Float32frombits(sign)
		}
		for mant&0x0400 == 0 {
			mant <<= 1
			exp--
		}
		exp++
		mant &= 0x03FF
	} else if exp == 0x1F {
		if mant == 0 {
			return math.Float32frombits(sign | 0x7F800000)
		}
		return math.Float32frombits(sign | 0x7FC00000)
	}
	exp = exp + (127 - 15)
	return math.Float32frombits(sign | (uint32(exp) << 23) | (mant << 13))
}

func bfloat16BitsToFloat32(bits uint16) float32 {
	return math.Float32frombits(uint32(bits) << 16)
}

func expf(x float32) float32 {
	if useFastExp {
		// Clamp to a reasonable range for stability.
		// For softmax weights, very negative values underflow to ~0 anyway.
		if x < -20 {
			x = -20
		} else if x > 10 {
			x = 10
		}
		// Schraudolph-style fast exp approximation.
		// Good tradeoff for softmax weights; much faster than math.Exp.
		const a = 12102203.0 // (1<<23)/ln(2)
		const b = 1065353216.0
		return math.Float32frombits(uint32(float32(a)*x + float32(b)))
	}
	return float32(math.Exp(float64(x)))
}
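
// fastExpSketch is an exposition-only sketch (assumed name, not used by any
// inference path) of the Schraudolph trick behind expf above: for x in the
// clamped range, exp(x) ≈ Float32frombits(uint32(a*x + b)), where
// a = 2^23/ln(2) ≈ 12102203 scales x into the float32 exponent field and
// b = 127*2^23 = 1065353216 is the bit pattern of float32(1.0), supplying the
// exponent bias. This helper only reports the approximation error against the
// exact math.Exp for a given input.
func fastExpSketch(x float32) (approx, exact, relErr float64) {
	approx = float64(expf(x))
	exact = math.Exp(float64(x))
	if exact != 0 {
		relErr = math.Abs(approx-exact) / exact
	}
	return approx, exact, relErr
}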

type viewData struct {
	kData  []float32
	vData  []float32
	start  int
	length int
}

// CausalAttentionCached computes causal attention using cached K/V
// Q: [newTokens, numHeads * headDim] - query for new tokens only
// K: [totalSeqLen, numKVHeads * headDim] - full K history including current
// V: [totalSeqLen, numKVHeads * headDim] - full V history including current
// Output: [newTokens, numHeads * headDim]
// startPos: position of first new token in sequence
func CausalAttentionCached(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim, startPos int) error {
	newTokens := q.Shape()[0]
	totalSeqLen := k.Shape()[0]

	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()

	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
		return nil
	}

	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

// CausalAttentionPackedBlocks computes causal attention over packed KV views.
// Packed layout is head-major: [kvHead][tokenWithinBlock][headDim] as a flat slice.
// This avoids kvDim-stride traversal and is a fast CPU path.
func CausalAttentionPackedBlocks(
	q *cpu.Tensor,
	views []kvcache.PackedView,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim, startPos int,
) error {
	newTokens := q.Shape()[0]
	qData := q.DataFloat32()
	outData := output.DataFloat32()

	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	// Sort to guarantee increasing start positions.
	sort.Slice(views, func(i, j int) bool { return views[i].Start < views[j].Start })

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
		return nil
	}

	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

func runCausalCachedHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim

	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		kvHeadOffset := kvHead * headDim

		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1
			if maxKeyPos > totalSeqLen {
				maxKeyPos = totalSeqLen
			}

			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]

			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			outPtr := &outData[outBase]
			clear(outVec)

			m := float32(-math.MaxFloat32)
			l := float32(0)

			for ti := 0; ti < maxKeyPos; ti++ {
				kBase := ti*strideKV + kvHeadOffset
				kPtr := &kData[kBase]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)

				vBase := ti*strideKV + kvHeadOffset
				vPtr := &vData[vBase]

				if s > m {
					alpha := expf(m - s)
					if l != 0 {
						for i := 0; i < headDim; i++ {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
					continue
				}

				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
			}

			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}

func runCausalPackedHeads(qData, outData []float32, views []kvcache.PackedView, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim

	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize

		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1

			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]

			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			outPtr := &outData[outBase]
			clear(outVec)

			m := float32(-math.MaxFloat32)
			l := float32(0)

			for _, pv := range views {
				if pv.Length == 0 || pv.Start >= maxKeyPos {
					continue
				}
				if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
					continue
				}

				blkStride := pv.BlockSize * headDim
				headBase := kvHead * blkStride
				if headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
					continue
				}

				viewLimit := pv.Length
				if pv.Start+viewLimit > maxKeyPos {
					viewLimit = maxKeyPos - pv.Start
				}

				kHead := pv.K[headBase : headBase+blkStride]
				vHead := pv.V[headBase : headBase+blkStride]

				for t := 0; t < viewLimit; t++ {
					kPtr := &kHead[t*headDim]
					s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)

					vPtr := &vHead[t*headDim]

					if s > m {
						alpha := expf(m - s)
						if l != 0 {
							for i := 0; i < headDim; i++ {
								outVec[i] *= alpha
							}
							l *= alpha
						}
						m = s
						l += 1
						cpu.AxpyPtr(1, vPtr, outPtr, headDim)
						continue
					}

					w := expf(s - m)
					l += w
					cpu.AxpyPtr(w, vPtr, outPtr, headDim)
				}
			}

			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}
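
// onlineSoftmaxSketch is a minimal, exposition-only sketch (assumed name) of the
// streaming softmax accumulation used by runCausalCachedHeads and
// runCausalPackedHeads above: keep a running max m and running weight sum l, and
// rescale the accumulator by alpha = exp(oldMax-newMax) whenever a larger score
// arrives, so a single pass over the keys suffices. It uses plain slices and the
// exact math.Exp instead of the SIMD helpers and expf; each values[t] must have
// at least headDim elements.
func onlineSoftmaxSketch(scores []float32, values [][]float32, headDim int) []float32 {
	out := make([]float32, headDim)
	m := float32(math.Inf(-1))
	l := float32(0)
	for t, s := range scores {
		if s > m {
			// New running max: rescale everything accumulated under the old max.
			alpha := float32(math.Exp(float64(m - s)))
			if l != 0 {
				for i := range out {
					out[i] *= alpha
				}
				l *= alpha
			}
			m = s
			l++
			for i := range out {
				out[i] += values[t][i]
			}
			continue
		}
		w := float32(math.Exp(float64(s - m)))
		l += w
		for i := range out {
			out[i] += w * values[t][i]
		}
	}
	// Normalize by the accumulated softmax denominator.
	if l != 0 {
		inv := 1 / l
		for i := range out {
			out[i] *= inv
		}
	}
	return out
}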

// CausalAttentionBlocks computes attention directly over KV block views without
// materializing a contiguous history tensor. startPos is the absolute position
// of the first new token (current cache length before the append).
func CausalAttentionBlocks(
	q *cpu.Tensor,
	views []kvcache.View,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim, startPos int,
) error {
	newTokens := q.Shape()[0]
	qData := q.DataFloat32()
	outData := output.DataFloat32()

	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	// Pre-extract data from all views (handles CPU and GPU tensors)
	viewsData := make([]viewData, len(views))
	for i, v := range views {
		if v.Length == 0 {
			continue
		}
		kData, err := tensorToFloat32(v.K)
		if err != nil {
			return fmt.Errorf("failed to get K data from view: %w", err)
		}
		vData, err := tensorToFloat32(v.V)
		if err != nil {
			return fmt.Errorf("failed to get V data from view: %w", err)
		}
		viewsData[i] = viewData{
			kData:  kData,
			vData:  vData,
			start:  v.Start,
			length: v.Length,
		}
	}
	sort.Slice(viewsData, func(i, j int) bool { return viewsData[i].start < viewsData[j].start })

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
		return nil
	}

	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

func runCausalBlockHeads(qData, outData []float32, viewsData []viewData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim

	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		kvHeadOffset := kvHead * headDim

		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1

			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]

			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			clear(outVec)

			m := float32(-math.MaxFloat32)
			l := float32(0)

			for _, vd := range viewsData {
				if vd.start >= maxKeyPos || vd.length == 0 {
					continue
				}

				viewLimit := vd.length
				if vd.start+viewLimit > maxKeyPos {
					viewLimit = maxKeyPos - vd.start
				}

				for local := 0; local < viewLimit; local++ {
					kvIdx := local*strideKV + kvHeadOffset
					kVec := vd.kData[kvIdx : kvIdx+headDim]
					s := cpu.DotFloat32(qVec, kVec) * float32(scale)

					vVec := vd.vData[kvIdx : kvIdx+headDim]

					if s > m {
						alpha := expf(m - s)
						if l != 0 {
							for i := 0; i < headDim; i++ {
								outVec[i] *= alpha
							}
							l *= alpha
						}
						m = s
						l += 1
						cpu.Axpy(1, vVec, outVec)
						continue
					}

					w := expf(s - m)
					l += w
					cpu.Axpy(w, vVec, outVec)
				}
			}

			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}
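
// visibleInView is a tiny illustrative helper (assumed name, not part of the
// kvcache API) that states the per-view causal clamp used by the block paths
// above: for a query at absolute position pos, a view covering absolute key
// positions [start, start+length) contributes only its first
// min(length, pos+1-start) tokens, and nothing at all once start > pos.
func visibleInView(start, length, pos int) int {
	maxKeyPos := pos + 1 // keys at absolute positions <= pos are visible
	if length == 0 || start >= maxKeyPos {
		return 0
	}
	if start+length > maxKeyPos {
		return maxKeyPos - start
	}
	return length
}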

// tensorToFloat32 extracts float32 data from a tensor, handling both CPU and CUDA tensors.
func tensorToFloat32(t tensor.Tensor) ([]float32, error) {
	switch tt := t.(type) {
	case *cpu.Tensor:
		switch tt.DType() {
		case tensor.Float32:
			return tt.DataFloat32(), nil
		case tensor.Float16:
			in := tt.DataUint16()
			out := make([]float32, len(in))
			for i := range in {
				out[i] = float16BitsToFloat32(in[i])
			}
			return out, nil
		case tensor.BFloat16:
			in := tt.DataUint16()
			out := make([]float32, len(in))
			for i := range in {
				out[i] = bfloat16BitsToFloat32(in[i])
			}
			return out, nil
		default:
			return nil, fmt.Errorf("unsupported CPU tensor dtype: %v", tt.DType())
		}
	case *cuda.Tensor:
		data := make([]float32, t.Shape().NumElements())
		if err := tt.CopyToHost(data); err != nil {
			return nil, err
		}
		return data, nil
	default:
		return nil, fmt.Errorf("unsupported tensor type: %T", t)
	}
}

func cpuDevice() tensor.DeviceType {
	return tensor.CPU
}
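
// halfPrecisionExamples is an exposition-only helper (assumed name) showing what
// the converters used by tensorToFloat32 produce for a few well-known bit
// patterns: IEEE-754 half 0x3C00 is 1.0 and 0x0001 is the smallest subnormal
// (2^-24), while bfloat16 simply occupies the top 16 bits of a float32, so
// 0x3F80 widens to the float32 bit pattern 0x3F800000, i.e. 1.0.
func halfPrecisionExamples() map[string]float32 {
	return map[string]float32{
		"fp16 0x3C00 (1.0)":   float16BitsToFloat32(0x3C00),
		"fp16 0xC000 (-2.0)":  float16BitsToFloat32(0xC000),
		"fp16 0x0001 (2^-24)": float16BitsToFloat32(0x0001),
		"bf16 0x3F80 (1.0)":   bfloat16BitsToFloat32(0x3F80),
		"bf16 0xC040 (-3.0)":  bfloat16BitsToFloat32(0xC040),
	}
}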