// attention.go
  1. package nn
import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
)
  7. // CausalAttention computes causal (masked) scaled dot-product attention with GQA support
  8. // Q: [seq_len, num_heads * head_dim]
  9. // K: [seq_len, num_kv_heads * head_dim]
  10. // V: [seq_len, num_kv_heads * head_dim]
  11. // Output: [seq_len, num_heads * head_dim]
  12. //
  13. // For Grouped Query Attention (GQA): num_heads is a multiple of num_kv_heads
  14. // Each KV head is shared across (num_heads / num_kv_heads) query heads
  15. func CausalAttention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
  16. seqLen := q.Shape()[0]
  17. qData := q.DataFloat32()
  18. kData := k.DataFloat32()
  19. vData := v.DataFloat32()
  20. outData := output.DataFloat32()
  21. scale := 1.0 / math.Sqrt(float64(headDim))
  22. // Number of Q heads per KV head (for GQA)
  23. groupSize := numHeads / numKVHeads
  24. workers := cpu.MaxThreads()
  25. if workers < 2 || numHeads < 2 {
  26. runCausalHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
  27. return nil
  28. }
  29. chunk := (numHeads + workers - 1) / workers
  30. var wg sync.WaitGroup
  31. for start := 0; start < numHeads; start += chunk {
  32. end := start + chunk
  33. if end > numHeads {
  34. end = numHeads
  35. }
  36. wg.Add(1)
  37. go func(s, e int) {
  38. defer wg.Done()
  39. runCausalHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
  40. }(start, end)
  41. }
  42. wg.Wait()
  43. return nil
  44. }
  45. func runCausalHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
  46. // Reuse a per-worker buffer to avoid per-token allocations.
  47. // max numKeys is seqLen.
  48. scoresBuf := make([]float32, seqLen)
  49. strideQ := numHeads * headDim
  50. strideKV := numKVHeads * headDim
  51. for h := hStart; h < hEnd; h++ {
  52. qHeadOffset := h * headDim
  53. kvHead := h / groupSize
  54. kvHeadOffset := kvHead * headDim
  55. for qi := 0; qi < seqLen; qi++ {
  56. numKeys := qi + 1
  57. scores := scoresBuf[:numKeys]
  58. qBase := qi*strideQ + qHeadOffset
  59. qVec := qData[qBase : qBase+headDim]
  60. for ki := 0; ki < numKeys; ki++ {
  61. kBase := ki*strideKV + kvHeadOffset
  62. kVec := kData[kBase : kBase+headDim]
  63. dot := cpu.DotFloat32(qVec, kVec)
  64. scores[ki] = dot * float32(scale)
  65. }
  66. softmaxInplace(scores)
  67. outBase := qi*strideQ + qHeadOffset
  68. outVec := outData[outBase : outBase+headDim]
  69. clear(outVec)
  70. for vi := 0; vi < numKeys; vi++ {
  71. alpha := scores[vi]
  72. vBase := vi*strideKV + kvHeadOffset
  73. vVec := vData[vBase : vBase+headDim]
  74. cpu.Axpy(alpha, vVec, outVec)
  75. }
  76. }
  77. }
  78. }
  79. // Attention computes full (non-causal) attention - for encoder models
  80. func Attention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
  81. seqLen := q.Shape()[0]
  82. qData := q.DataFloat32()
  83. kData := k.DataFloat32()
  84. vData := v.DataFloat32()
  85. outData := output.DataFloat32()
  86. scale := 1.0 / math.Sqrt(float64(headDim))
  87. groupSize := numHeads / numKVHeads
  88. workers := cpu.MaxThreads()
  89. if workers < 2 || numHeads < 2 {
  90. runFullHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
  91. return nil
  92. }
  93. chunk := (numHeads + workers - 1) / workers
  94. var wg sync.WaitGroup
  95. for start := 0; start < numHeads; start += chunk {
  96. end := start + chunk
  97. if end > numHeads {
  98. end = numHeads
  99. }
  100. wg.Add(1)
  101. go func(s, e int) {
  102. defer wg.Done()
  103. runFullHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
  104. }(start, end)
  105. }
  106. wg.Wait()
  107. return nil
  108. }
  109. func runFullHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
  110. // Reuse a per-worker buffer to avoid per-token allocations.
  111. scores := make([]float32, seqLen)
  112. strideQ := numHeads * headDim
  113. strideKV := numKVHeads * headDim
  114. for h := hStart; h < hEnd; h++ {
  115. qHeadOffset := h * headDim
  116. kvHead := h / groupSize
  117. kvHeadOffset := kvHead * headDim
  118. for qi := 0; qi < seqLen; qi++ {
  119. qBase := qi*strideQ + qHeadOffset
  120. qVec := qData[qBase : qBase+headDim]
  121. for ki := 0; ki < seqLen; ki++ {
  122. kBase := ki*strideKV + kvHeadOffset
  123. kVec := kData[kBase : kBase+headDim]
  124. dot := cpu.DotFloat32(qVec, kVec)
  125. scores[ki] = dot * float32(scale)
  126. }
  127. softmaxInplace(scores)
  128. outBase := qi*strideQ + qHeadOffset
  129. outVec := outData[outBase : outBase+headDim]
  130. clear(outVec)
  131. for vi := 0; vi < seqLen; vi++ {
  132. alpha := scores[vi]
  133. vBase := vi*strideKV + kvHeadOffset
  134. vVec := vData[vBase : vBase+headDim]
  135. cpu.Axpy(alpha, vVec, outVec)
  136. }
  137. }
  138. }
  139. }