self_attention.go

package nn

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)
// SelfAttentionConfig holds parameters for self-attention computation.
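// HeadDim is the per-head dimension, RopeTheta is the theta passed to RoPE,
// and RMSNormEps is the epsilon used by the optional QK RMSNorm. NumHeads and
// NumKVHeads are the configured query and key/value head counts; they differ
// under grouped-query attention (the function below derives the effective
// counts from the projection weight shapes rather than from these fields).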
type SelfAttentionConfig struct {
	HeadDim    int
	NumHeads   int
	NumKVHeads int
	RopeTheta  float32
	RMSNormEps float32
}
// SelfAttention computes full self-attention with optional QK norm and KV cache.
// x: input tensor [seqLen, hiddenSize]
// wq, wk, wv, wo: projection weights
// qNorm, kNorm: optional QK normalization weights (can be nil)
// positions: position indices for RoPE
// cache: optional KV cache (can be nil)
// layerIdx: layer index for cache
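//
// Illustrative usage (a sketch, not code from this package: the dims are made
// up, real weights would be loaded rather than zero-initialized, and the
// [outFeatures, inFeatures] weight layout is inferred from how Shape()[0] is
// used in the body below):
//
//	seqLen, hiddenSize := 4, 512
//	cfg := SelfAttentionConfig{HeadDim: 64, NumHeads: 8, NumKVHeads: 2, RopeTheta: 10000, RMSNormEps: 1e-6}
//	x := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
//	wq := ops.Zeros(tensor.Shape{cfg.NumHeads * cfg.HeadDim, hiddenSize})
//	wk := ops.Zeros(tensor.Shape{cfg.NumKVHeads * cfg.HeadDim, hiddenSize})
//	wv := ops.Zeros(tensor.Shape{cfg.NumKVHeads * cfg.HeadDim, hiddenSize})
//	wo := ops.Zeros(tensor.Shape{hiddenSize, cfg.NumHeads * cfg.HeadDim})
//	positions := make([]int, seqLen)
//	for i := range positions {
//		positions[i] = i
//	}
//	out := SelfAttention(x, wq, wk, wv, wo, nil, nil, positions, cfg, nil, 0)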
func SelfAttention(
	x *cpu.Tensor,
	wq, wk, wv, wo *cpu.Tensor,
	qNorm, kNorm *cpu.Tensor,
	positions []int,
	cfg SelfAttentionConfig,
	cache kvcache.KVCacheInterface,
	layerIdx int,
) *cpu.Tensor {
	seqLen := x.Shape()[0]
	hiddenSize := wo.Shape()[0]
	wqShape := wq.Shape()
	wkShape := wk.Shape()
	wvShape := wv.Shape()

	// Q, K, V projections
	xq := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})
	xk := ops.Zeros(tensor.Shape{seqLen, wkShape[0]})
	xv := ops.Zeros(tensor.Shape{seqLen, wvShape[0]})
	matmul.Linear(x, wq, xq)
	matmul.Linear(x, wk, xk)
	matmul.Linear(x, wv, xv)
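	// xq is [seqLen, numQHeads*HeadDim]; xk and xv are [seqLen, numKVHeads*HeadDim],
	// which can be narrower when the model uses grouped-query attention.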

	// Apply QK norm if available (Qwen3)
	if qNorm != nil {
		RMSNorm(xq, qNorm, cfg.RMSNormEps)
	}
	if kNorm != nil {
		RMSNorm(xk, kNorm, cfg.RMSNormEps)
	}

	// RoPE
	RoPE(xq, positions, cfg.HeadDim, cfg.RopeTheta)
	RoPE(xk, positions, cfg.HeadDim, cfg.RopeTheta)
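
	// Head counts are derived from the projection widths rather than read from
	// cfg, so they always match the loaded weights.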
	numQHeads := wqShape[0] / cfg.HeadDim
	numKVHeads := wkShape[0] / cfg.HeadDim
	attnOut := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})

	// KV Cache Logic
	if cache != nil {
		views, startPos, err := cache.Append(layerIdx, xk, xv)
		if err != nil {
			panic(err)
		}
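		// Prefer the packed fast path when the cache can expose packed views
		// of its blocks; otherwise fall back to the per-block views returned
		// by Append.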
		if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
			pviews := pv.ViewsPacked(layerIdx)
			if len(pviews) > 0 {
				if err := CausalAttentionPackedBlocks(xq, pviews, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
					panic(err)
				}
				goto done
			}
		}
		if err := CausalAttentionBlocks(xq, views, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
			panic(err)
		}
	} else {
		CausalAttention(xq, xk, xv, attnOut, numQHeads, numKVHeads, cfg.HeadDim)
	}
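	// The packed fast path jumps straight here once attnOut has been filled.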
done:
	// Output projection
	out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
	matmul.Linear(attnOut, wo, out)
	return out
}