package nn

import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
)

// headRangeFunc is the signature shared by the per-head attention kernels
// (runCausalHeads and runFullHeads). An implementation processes the query
// heads in the half-open range [hStart, hEnd).
type headRangeFunc func(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int)

// CausalAttention computes causal (masked) scaled dot-product attention with
// GQA support: query position i attends only to key positions 0..i.
//
//	Q:      [seq_len, num_heads * head_dim]
//	K:      [seq_len, num_kv_heads * head_dim]
//	V:      [seq_len, num_kv_heads * head_dim]
//	Output: [seq_len, num_heads * head_dim]
//
// For Grouped Query Attention (GQA), numHeads must be a multiple of
// numKVHeads; each KV head is shared across (numHeads / numKVHeads) query
// heads. seq_len is taken from the first dimension of q.
//
// An error is returned, and output is left untouched, when an argument is nil,
// a dimension is not positive, numHeads is not a multiple of numKVHeads, or a
// tensor's float32 data is too short for the implied shape.
func CausalAttention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
	return attentionOverHeads(q, k, v, output, numHeads, numKVHeads, headDim, runCausalHeads)
}

// Attention computes full (non-causal) scaled dot-product attention, as used
// by encoder models: every query position attends to every key position.
//
// Tensor layouts, GQA handling and error conditions are the same as for
// CausalAttention.
func Attention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
	return attentionOverHeads(q, k, v, output, numHeads, numKVHeads, headDim, runFullHeads)
}

// attentionOverHeads validates the arguments and then runs the kernel `run`
// over all query heads, splitting the heads into contiguous chunks that are
// processed concurrently when more than one worker thread is available.
//
// Validation happens up front because a panic raised inside a worker goroutine
// (division by zero, out-of-range slice) cannot be recovered by the caller.
func attentionOverHeads(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int, run headRangeFunc) error {
	if q == nil || k == nil || v == nil || output == nil {
		return fmt.Errorf("nn: attention: nil tensor argument (q=%t k=%t v=%t output=%t)",
			q != nil, k != nil, v != nil, output != nil)
	}
	if numHeads <= 0 || numKVHeads <= 0 || headDim <= 0 {
		return fmt.Errorf("nn: attention: numHeads (%d), numKVHeads (%d) and headDim (%d) must be positive",
			numHeads, numKVHeads, headDim)
	}
	// A non-divisible head count would map the trailing query heads onto a KV
	// head index >= numKVHeads, i.e. onto the next row's data or past the end.
	if numHeads%numKVHeads != 0 {
		return fmt.Errorf("nn: attention: numHeads (%d) must be a multiple of numKVHeads (%d)",
			numHeads, numKVHeads)
	}

	shape := q.Shape()
	if len(shape) == 0 {
		return fmt.Errorf("nn: attention: query tensor has no dimensions")
	}
	seqLen := shape[0]
	if seqLen < 0 {
		return fmt.Errorf("nn: attention: invalid sequence length %d", seqLen)
	}

	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()

	// The kernels index rows with strides numHeads*headDim (Q, output) and
	// numKVHeads*headDim (K, V) for seqLen rows each.
	needQ := seqLen * numHeads * headDim
	needKV := seqLen * numKVHeads * headDim
	if len(qData) < needQ || len(outData) < needQ {
		return fmt.Errorf("nn: attention: q/output hold %d/%d float32 values, need at least %d",
			len(qData), len(outData), needQ)
	}
	if len(kData) < needKV || len(vData) < needKV {
		return fmt.Errorf("nn: attention: k/v hold %d/%d float32 values, need at least %d",
			len(kData), len(vData), needKV)
	}

	scale := 1.0 / math.Sqrt(float64(headDim))

	// Number of Q heads per KV head (for GQA).
	groupSize := numHeads / numKVHeads

	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		run(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
		return nil
	}

	// Ceiling division so that at most `workers` goroutines are started.
	// Each goroutine writes only the output columns of its own heads, so the
	// workers never touch the same output element.
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := min(start+chunk, numHeads)
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			run(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

// runCausalHeads computes causal attention for query heads [hStart, hEnd).
//
// Preconditions (checked by attentionOverHeads, not here): groupSize > 0,
// every head in the range maps to a KV head below numKVHeads, and the data
// slices cover seqLen rows of the corresponding stride.
func runCausalHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
	// Reuse a per-worker buffer to avoid per-token allocations.
	// The largest number of keys any query sees is seqLen.
	scoresBuf := make([]float32, seqLen)
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	scale32 := float32(scale)

	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize // KV head shared by this query head's group
		kvHeadOffset := kvHead * headDim

		for qi := 0; qi < seqLen; qi++ {
			// Causal mask: query qi only sees keys 0..qi.
			numKeys := qi + 1
			scores := scoresBuf[:numKeys]

			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]

			for ki := 0; ki < numKeys; ki++ {
				kBase := ki*strideKV + kvHeadOffset
				kVec := kData[kBase : kBase+headDim]
				scores[ki] = cpu.DotFloat32(qVec, kVec) * scale32
			}

			// Turn the scaled scores into attention weights.
			softmaxInplace(scores)

			// Output shares Q's layout, so the same base offset applies.
			outVec := outData[qBase : qBase+headDim]
			clear(outVec)

			// Accumulate the weighted sum of value vectors into outVec.
			for vi := 0; vi < numKeys; vi++ {
				vBase := vi*strideKV + kvHeadOffset
				cpu.Axpy(scores[vi], vData[vBase:vBase+headDim], outVec)
			}
		}
	}
}

// runFullHeads computes full (non-causal) attention for query heads
// [hStart, hEnd): every query position attends to all seqLen key positions.
//
// Preconditions are the same as for runCausalHeads and are likewise checked
// by attentionOverHeads rather than here.
func runFullHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
	// Reuse a per-worker buffer to avoid per-token allocations.
	scores := make([]float32, seqLen)
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	scale32 := float32(scale)

	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize // KV head shared by this query head's group
		kvHeadOffset := kvHead * headDim

		for qi := 0; qi < seqLen; qi++ {
			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]

			for ki := 0; ki < seqLen; ki++ {
				kBase := ki*strideKV + kvHeadOffset
				kVec := kData[kBase : kBase+headDim]
				scores[ki] = cpu.DotFloat32(qVec, kVec) * scale32
			}

			// Turn the scaled scores into attention weights.
			softmaxInplace(scores)

			// Output shares Q's layout, so the same base offset applies.
			outVec := outData[qBase : qBase+headDim]
			clear(outVec)

			// Accumulate the weighted sum of value vectors into outVec.
			for vi := 0; vi < seqLen; vi++ {
				vBase := vi*strideKV + kvHeadOffset
				cpu.Axpy(scores[vi], vData[vBase:vBase+headDim], outVec)
			}
		}
	}
}