attention_batch.go

package nn

import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
)
// CausalAttentionPackedBlocksBatch computes causal attention for a batch in
// which each row of q/output is an independent sequence (decode-style batch).
//
//	q:            [numTokens, numHeads*headDim] (float32, CPU)
//	viewsByToken: per-token packed KV views (head-major blocks)
//	output:       [numTokens, numHeads*headDim] (float32, CPU)
//	queryPos:     per-token absolute query position (startPos for that token)
//
// The implementation is optimized for CPU KV caches and uses SIMD-backed
// Dot/Axpy kernels when available. A usage sketch follows the function body.
func CausalAttentionPackedBlocksBatch(
	q *cpu.Tensor,
	viewsByToken [][]kvcache.PackedView,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim int,
	queryPos []int,
) error {
	if q == nil || output == nil {
		return fmt.Errorf("nil tensor")
	}
	if numHeads <= 0 || numKVHeads <= 0 || headDim <= 0 {
		return fmt.Errorf("invalid heads/dim: numHeads=%d numKVHeads=%d headDim=%d", numHeads, numKVHeads, headDim)
	}
	if numHeads%numKVHeads != 0 {
		return fmt.Errorf("numHeads %d not divisible by numKVHeads %d", numHeads, numKVHeads)
	}
	qShape := q.Shape()
	outShape := output.Shape()
	if len(qShape) != 2 || len(outShape) != 2 {
		return fmt.Errorf("expected 2D tensors (q=%v out=%v)", qShape, outShape)
	}
	numTokens := qShape[0]
	qStride := qShape[1]
	if numTokens <= 0 {
		return nil
	}
	if outShape[0] != numTokens || outShape[1] != qStride {
		return fmt.Errorf("output shape mismatch: q=%v out=%v", qShape, outShape)
	}
	if qStride != numHeads*headDim {
		return fmt.Errorf("q stride mismatch: got %d want %d (numHeads*headDim)", qStride, numHeads*headDim)
	}
	if len(viewsByToken) != numTokens {
		return fmt.Errorf("viewsByToken len %d != numTokens %d", len(viewsByToken), numTokens)
	}
	if len(queryPos) != numTokens {
		return fmt.Errorf("queryPos len %d != numTokens %d", len(queryPos), numTokens)
	}
	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	groupSize := numHeads / numKVHeads

	// One work item per (token, head) pair; chunk them evenly across workers.
	workItems := numTokens * numHeads
	workers := cpu.MaxThreads()
	if workers < 2 || workItems < 2 {
		runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, 0, workItems)
		return nil
	}
	if workers > workItems {
		workers = workItems
	}
	chunk := (workItems + workers - 1) / workers
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		start := w * chunk
		end := start + chunk
		if end > workItems {
			end = workItems
		}
		go func(s, e int) {
			defer wg.Done()
			runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
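
// Usage sketch (illustrative only; how q/output are allocated and how
// viewsByToken is produced by the kvcache package are assumptions here):
//
//	// q and output are [numTokens, numHeads*headDim] float32 CPU tensors;
//	// viewsByToken[i] lists the packed KV blocks visible to token i, and
//	// queryPos[i] is that token's absolute position in its sequence.
//	if err := CausalAttentionPackedBlocksBatch(
//		q, viewsByToken, output,
//		numHeads, numKVHeads, headDim, queryPos,
//	); err != nil {
//		return err
//	}
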
// runPackedBatchWork processes flattened work items in [workStart, workEnd),
// where item idx maps to token idx/numHeads and head idx%numHeads.
func runPackedBatchWork(
	qData, outData []float32,
	viewsByToken [][]kvcache.PackedView,
	queryPos []int,
	numTokens, numHeads, numKVHeads, headDim, groupSize, qStride int,
	scale float32,
	workStart, workEnd int,
) {
	for idx := workStart; idx < workEnd; idx++ {
		tok := idx / numHeads
		head := idx - tok*numHeads
		if tok < 0 || tok >= numTokens {
			continue
		}
		qHeadOffset := head * headDim
		qBase := tok*qStride + qHeadOffset
		if qBase < 0 || qBase+headDim > len(qData) {
			continue
		}
		outBase := tok*qStride + qHeadOffset
		if outBase < 0 || outBase+headDim > len(outData) {
			continue
		}
		qPtr := &qData[qBase]
		outVec := outData[outBase : outBase+headDim]
		outPtr := &outData[outBase]
		clear(outVec)
		// Grouped-query attention: groupSize query heads share one KV head.
		kvHead := head / groupSize
		// Causal mask: token tok may attend to key positions < queryPos[tok]+1.
		maxKeyPos := queryPos[tok] + 1
		if maxKeyPos <= 0 {
			continue
		}
		m := float32(-math.MaxFloat32)
		l := float32(0)
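		// Streaming (online) softmax over all visible keys: m tracks the
		// running max score and l the running sum of exp(s-m), while outVec
		// accumulates the exp-weighted sum of V rows. Whenever a new max
		// arrives, the accumulator and l are rescaled by exp(oldMax-newMax),
		// so outVec/l ends up equal to softmax(scores) @ V without ever
		// materializing the full score vector (the FlashAttention-style
		// recurrence).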
		for _, pv := range viewsByToken[tok] {
			if pv.Length == 0 || pv.Start >= maxKeyPos {
				continue
			}
			if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
				continue
			}
			blkStride := pv.BlockSize * headDim
			headBase := kvHead * blkStride
			if blkStride <= 0 || headBase < 0 || headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
				continue
			}
			// Clip the view so no key at or beyond maxKeyPos is attended.
			viewLimit := pv.Length
			if pv.Start+viewLimit > maxKeyPos {
				viewLimit = maxKeyPos - pv.Start
			}
			if viewLimit <= 0 {
				continue
			}
			kHead := pv.K[headBase : headBase+blkStride]
			vHead := pv.V[headBase : headBase+blkStride]
			for t := 0; t < viewLimit; t++ {
				kPtr := &kHead[t*headDim]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * scale
				vPtr := &vHead[t*headDim]
				if s > m {
					// New running max: rescale what we have by exp(m-s),
					// then add this V row with weight exp(s-s) = 1.
					alpha := expf(m - s)
					if l != 0 {
						for i := 0; i < headDim; i++ {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
					continue
				}
				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
			}
		}
		// Final normalization: divide the accumulated weighted sum by l.
		if l != 0 {
			inv := 1 / l
			for i := 0; i < headDim; i++ {
				outVec[i] *= inv
			}
		}
	}
}
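
// naiveAttentionReference is a self-contained sketch (an assumption, not part
// of the original file) of the direct two-pass computation the streaming loop
// above replaces: materialize every score, softmax them, then take the
// weighted sum of V rows. Handy for cross-checking runPackedBatchWork on
// small inputs.
func naiveAttentionReference(scores []float32, v [][]float32, headDim int) []float32 {
	out := make([]float32, headDim)
	if len(scores) == 0 {
		return out
	}
	// Pass 1: find the max score for numerical stability.
	m := scores[0]
	for _, s := range scores[1:] {
		if s > m {
			m = s
		}
	}
	// Pass 2: accumulate softmax weights and the weighted sum of V rows.
	var l float32
	for t, s := range scores {
		w := expf(s - m)
		l += w
		for i := 0; i < headDim; i++ {
			out[i] += w * v[t][i]
		}
	}
	inv := 1 / l
	for i := 0; i < headDim; i++ {
		out[i] *= inv
	}
	return out
}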