package nn

import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
)

// CausalAttentionCachedKV computes causal scaled dot-product attention for the
// rows of q against the rows of k and v, writing the result into output.
//
// The number of query rows is q.Shape()[0] and the number of key/value rows is
// k.Shape()[0]. The flat float32 buffers are indexed as row-major matrices:
//
//	q:      row stride numHeads*headDimK,   head h at offset h*headDimK
//	k:      row stride numKVHeads*headDimK, KV head j at offset j*headDimK
//	v:      row stride numKVHeads*headDimV, KV head j at offset j*headDimV
//	output: row stride numHeads*headDimV,   head h at offset h*headDimV
//
// Query head h reads KV head h/(numHeads/numKVHeads). Query row qi attends to
// key rows [0, min(startPos+qi+1, k.Shape()[0])). Scores are scaled by
// 1/sqrt(headDimK).
//
// When cpu.MaxThreads() reports at least two workers and there are at least
// two heads, contiguous ranges of heads are processed in separate goroutines;
// every head writes to its own region of output.
//
// An error is returned, and output is left untouched, when a tensor is nil,
// when a head count or head dimension is not positive, when numHeads is not a
// multiple of numKVHeads, or when a buffer is too small for the elements the
// computation would touch.
func CausalAttentionCachedKV(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDimK, headDimV, startPos int) error {
	if q == nil || k == nil || v == nil || output == nil {
		return fmt.Errorf("causal attention: q, k, v and output tensors must be non-nil")
	}
	if numHeads <= 0 || numKVHeads <= 0 {
		return fmt.Errorf("causal attention: head counts must be positive (numHeads=%d, numKVHeads=%d)", numHeads, numKVHeads)
	}
	// A non-integral group size would map some query heads to a KV head index
	// >= numKVHeads (and a zero group size would divide by zero).
	if numHeads%numKVHeads != 0 {
		return fmt.Errorf("causal attention: numHeads (%d) must be a multiple of numKVHeads (%d)", numHeads, numKVHeads)
	}
	if headDimK <= 0 || headDimV <= 0 {
		return fmt.Errorf("causal attention: head dimensions must be positive (headDimK=%d, headDimV=%d)", headDimK, headDimV)
	}

	newTokens := q.Shape()[0]
	totalSeqLen := k.Shape()[0]
	if newTokens <= 0 {
		// No query rows: nothing is read or written.
		return nil
	}

	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()

	// The last query row has the widest causal window; key/value rows past it
	// are never read, so only that many rows are required to be present.
	usedKeys := max(0, min(startPos+newTokens, totalSeqLen))
	buffers := []struct {
		name       string
		have, need int
	}{
		{"q", len(qData), newTokens * numHeads * headDimK},
		{"k", len(kData), usedKeys * numKVHeads * headDimK},
		{"v", len(vData), usedKeys * numKVHeads * headDimV},
		{"output", len(outData), newTokens * numHeads * headDimV},
	}
	for _, b := range buffers {
		// The kernels receive raw element pointers, so sizes must be verified
		// here rather than relying on slice bounds checks.
		if b.have < b.need {
			return fmt.Errorf("causal attention: %s buffer holds %d float32 values, need at least %d", b.name, b.have, b.need)
		}
	}

	scale := 1.0 / math.Sqrt(float64(headDimK))
	groupSize := numHeads / numKVHeads

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
		return nil
	}

	// Ceil-divide so that at most `workers` goroutines are started.
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := min(start+chunk, numHeads)
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

// runCausalCachedHeadsKV computes causal attention for query heads in
// [hStart, hEnd) over all newTokens query rows, overwriting the matching
// headDimV-wide regions of outData.
//
// The softmax is evaluated in a single pass over the keys: m tracks the
// largest score seen so far and l the sum of exp(score - m), with the
// accumulated output rescaled whenever m grows. The output vector is divided
// by l at the end. If a query row has no visible keys its output stays zero.
//
// groupSize must be non-zero. The caller is responsible for making sure the
// buffers are large enough for the strides implied by the head counts and
// dimensions.
//
// NOTE(review): cpu.DotFloat32Ptr is presumed to return the dot product of n
// float32 values, cpu.AxpyPtr(a, x, y, n) is presumed to perform y += a*x over
// n values, and expf is presumed to be a float32 exponential — confirm against
// their definitions.
func runCausalCachedHeadsKV(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDimK
	strideK := numKVHeads * headDimK
	strideV := numKVHeads * headDimV
	strideOut := numHeads * headDimV
	scale32 := float32(scale)

	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDimK
		outHeadOffset := h * headDimV
		// Consecutive groups of groupSize query heads share one KV head.
		kvHead := h / groupSize
		kHeadOffset := kvHead * headDimK
		vHeadOffset := kvHead * headDimV

		for qi := 0; qi < newTokens; qi++ {
			// Causal mask: query row qi sits at absolute position startPos+qi
			// and may see keys up to and including that position.
			maxKeyPos := min(startPos+qi+1, totalSeqLen)

			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]
			outBase := qi*strideOut + outHeadOffset
			outVec := outData[outBase : outBase+headDimV]
			outPtr := &outData[outBase]
			clear(outVec)

			m := float32(-math.MaxFloat32) // running maximum score
			l := float32(0)                // running sum of exp(score - m)

			for ti := 0; ti < maxKeyPos; ti++ {
				kBase := ti*strideK + kHeadOffset
				kPtr := &kData[kBase]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDimK) * scale32
				vBase := ti*strideV + vHeadOffset
				vPtr := &vData[vBase]

				if s > m {
					// New maximum: rescale what has been accumulated so far so
					// that it is expressed relative to s instead of m.
					alpha := expf(m - s)
					if l != 0 {
						for i := range outVec {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					// exp(s - m) == 1 for the new maximum itself.
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDimV)
					continue
				}

				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDimV)
			}

			// Normalize by the softmax denominator; skipped when no key
			// contributed, leaving the zeroed output in place.
			if l != 0 {
				inv := 1 / l
				for i := range outVec {
					outVec[i] *= inv
				}
			}
		}
	}
}