package nn

import (
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
)

// CausalAttention computes causal (masked) scaled dot-product attention with
// GQA support.
//
//	Q:      [seq_len, num_heads * head_dim]
//	K:      [seq_len, num_kv_heads * head_dim]
//	V:      [seq_len, num_kv_heads * head_dim]
//	Output: [seq_len, num_heads * head_dim]
//
// For Grouped Query Attention (GQA), num_heads must be a multiple of
// num_kv_heads; each KV head is shared across (num_heads / num_kv_heads)
// consecutive query heads.
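//
// For example, with num_heads = 8 and num_kv_heads = 2 the group size is 4:
// query heads 0-3 share KV head 0 and query heads 4-7 share KV head 1
// (kvHead = h / groupSize in the workers below).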
func CausalAttention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
	seqLen := q.Shape()[0]
	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))

	// Number of Q heads per KV head (for GQA).
	groupSize := numHeads / numKVHeads

	// Run serially when parallelism cannot help.
	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
		return nil
	}

	// Split the heads into contiguous chunks, one goroutine per chunk
	// (ceiling division so every head is covered).
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
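
// Usage sketch for CausalAttention. How the cpu.Tensor values are built is not
// shown in this file, so the setup below is schematic; only the call itself
// reflects the real signature.
//
//	// q:    [seqLen, 8*64] float32
//	// k, v: [seqLen, 2*64] float32
//	// out:  [seqLen, 8*64] float32, written in place
//	if err := CausalAttention(q, k, v, out, 8, 2, 64); err != nil {
//		// handle error
//	}
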
// runCausalHeads computes causal attention outputs for query heads
// [hStart, hEnd).
func runCausalHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
	// Reuse a per-worker buffer to avoid per-token allocations;
	// numKeys is at most seqLen.
	scoresBuf := make([]float32, seqLen)

	// Row strides: Q/output rows hold numHeads heads, K/V rows hold numKVHeads.
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize // GQA: shared KV head for this query head
		kvHeadOffset := kvHead * headDim
		for qi := 0; qi < seqLen; qi++ {
			// Causal mask: position qi attends only to keys 0..qi.
			numKeys := qi + 1
			scores := scoresBuf[:numKeys]
			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]
			for ki := 0; ki < numKeys; ki++ {
				kBase := ki*strideKV + kvHeadOffset
				kVec := kData[kBase : kBase+headDim]
				dot := cpu.DotFloat32(qVec, kVec)
				scores[ki] = dot * float32(scale)
			}
			softmaxInplace(scores)

			// Output row = softmax-weighted sum of the visible value rows.
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			clear(outVec)
			for vi := 0; vi < numKeys; vi++ {
				alpha := scores[vi]
				vBase := vi*strideKV + kvHeadOffset
				vVec := vData[vBase : vBase+headDim]
				cpu.Axpy(alpha, vVec, outVec)
			}
		}
	}
}

// Attention computes full (non-causal) scaled dot-product attention, in which
// every position attends to the entire sequence, as encoder models require.
// Tensor shapes and the GQA layout match CausalAttention.
func Attention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
	seqLen := q.Shape()[0]
	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads

	// Run serially when parallelism cannot help.
	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runFullHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
		return nil
	}

	// Same head-chunking scheme as CausalAttention: ceiling division so every
	// head is covered, one goroutine per chunk.
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runFullHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

// runFullHeads computes non-causal attention outputs for query heads
// [hStart, hEnd); unlike runCausalHeads, every query position attends to all
// seqLen keys.
func runFullHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
	// Reuse a per-worker buffer to avoid per-token allocations.
	scores := make([]float32, seqLen)
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize // GQA: shared KV head for this query head
		kvHeadOffset := kvHead * headDim
		for qi := 0; qi < seqLen; qi++ {
			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]
			for ki := 0; ki < seqLen; ki++ {
				kBase := ki*strideKV + kvHeadOffset
				kVec := kData[kBase : kBase+headDim]
				dot := cpu.DotFloat32(qVec, kVec)
				scores[ki] = dot * float32(scale)
			}
			softmaxInplace(scores)
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			clear(outVec)
			for vi := 0; vi < seqLen; vi++ {
				alpha := scores[vi]
				vBase := vi*strideKV + kvHeadOffset
				vVec := vData[vBase : vBase+headDim]
				cpu.Axpy(alpha, vVec, outVec)
			}
		}
	}
}
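
// softmaxInplace is called above but not defined in this file. A minimal,
// numerically stable in-place softmax consistent with that usage would look
// like the sketch below (an assumed implementation, not necessarily the
// package's own).
func softmaxInplace(x []float32) {
	if len(x) == 0 {
		return
	}
	// Subtract the max before exponentiating to avoid overflow for large scores.
	maxVal := x[0]
	for _, v := range x[1:] {
		if v > maxVal {
			maxVal = v
		}
	}
	var sum float32
	for i, v := range x {
		e := float32(math.Exp(float64(v - maxVal)))
		x[i] = e
		sum += e
	}
	inv := 1 / sum
	for i := range x {
		x[i] *= inv
	}
}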