attention_cached_kv.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. package nn
import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
)
  7. func CausalAttentionCachedKV(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDimK, headDimV, startPos int) error {
  8. newTokens := q.Shape()[0]
  9. totalSeqLen := k.Shape()[0]
  10. qData := q.DataFloat32()
  11. kData := k.DataFloat32()
  12. vData := v.DataFloat32()
  13. outData := output.DataFloat32()
  14. scale := 1.0 / math.Sqrt(float64(headDimK))
  15. groupSize := numHeads / numKVHeads
  16. workers := cpu.MaxThreads()
  17. if workers < 2 || numHeads < 2 {
  18. runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
  19. return nil
  20. }
  21. chunk := (numHeads + workers - 1) / workers
  22. var wg sync.WaitGroup
  23. for start := 0; start < numHeads; start += chunk {
  24. end := start + chunk
  25. if end > numHeads {
  26. end = numHeads
  27. }
  28. wg.Add(1)
  29. go func(s, e int) {
  30. defer wg.Done()
  31. runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
  32. }(start, end)
  33. }
  34. wg.Wait()
  35. return nil
  36. }
  37. func runCausalCachedHeadsKV(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
  38. strideQ := numHeads * headDimK
  39. strideK := numKVHeads * headDimK
  40. strideV := numKVHeads * headDimV
  41. strideOut := numHeads * headDimV
  42. for h := hStart; h < hEnd; h++ {
  43. qHeadOffset := h * headDimK
  44. outHeadOffset := h * headDimV
  45. kvHead := h / groupSize
  46. kHeadOffset := kvHead * headDimK
  47. vHeadOffset := kvHead * headDimV
  48. for qi := 0; qi < newTokens; qi++ {
  49. maxKeyPos := startPos + qi + 1
  50. if maxKeyPos > totalSeqLen {
  51. maxKeyPos = totalSeqLen
  52. }
  53. qBase := qi*strideQ + qHeadOffset
  54. qPtr := &qData[qBase]
  55. outBase := qi*strideOut + outHeadOffset
  56. outVec := outData[outBase : outBase+headDimV]
  57. outPtr := &outData[outBase]
  58. clear(outVec)
  59. m := float32(-math.MaxFloat32)
  60. l := float32(0)
  61. for ti := 0; ti < maxKeyPos; ti++ {
  62. kBase := ti*strideK + kHeadOffset
  63. kPtr := &kData[kBase]
  64. s := cpu.DotFloat32Ptr(qPtr, kPtr, headDimK) * float32(scale)
  65. vBase := ti*strideV + vHeadOffset
  66. vPtr := &vData[vBase]
  67. if s > m {
  68. alpha := expf(m - s)
  69. if l != 0 {
  70. for i := 0; i < headDimV; i++ {
  71. outVec[i] *= alpha
  72. }
  73. l *= alpha
  74. }
  75. m = s
  76. l += 1
  77. cpu.AxpyPtr(1, vPtr, outPtr, headDimV)
  78. continue
  79. }
  80. w := expf(s - m)
  81. l += w
  82. cpu.AxpyPtr(w, vPtr, outPtr, headDimV)
  83. }
  84. if l != 0 {
  85. inv := 1 / l
  86. for i := 0; i < headDimV; i++ {
  87. outVec[i] *= inv
  88. }
  89. }
  90. }
  91. }
  92. }