package nn

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)
// SelfAttentionConfig holds parameters for self-attention computation.
type SelfAttentionConfig struct {
	HeadDim    int     // dimension of each attention head
	NumHeads   int     // number of query heads
	NumKVHeads int     // number of key/value heads (smaller than NumHeads under grouped-query attention)
	RopeTheta  float32 // RoPE base frequency
	RMSNormEps float32 // epsilon used by RMSNorm (and the optional QK norm)
}
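// Illustrative only: a configuration for a hypothetical grouped-query
// attention model. The concrete values below are assumptions made for the
// sake of the example, not constants taken from this repository.
//
//	cfg := SelfAttentionConfig{
//		HeadDim:    128,
//		NumHeads:   16,
//		NumKVHeads: 8, // two query heads share each KV head
//		RopeTheta:  1e6,
//		RMSNormEps: 1e-6,
//	}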
// SelfAttention computes full self-attention with optional QK norm and KV cache.
//
// x: input tensor [seqLen, hiddenSize]
// wq, wk, wv, wo: projection weights
// qNorm, kNorm: optional QK normalization weights (can be nil)
// positions: position indices for RoPE
// cache: optional KV cache (can be nil)
// layerIdx: layer index for the cache
func SelfAttention(
	x *cpu.Tensor,
	wq, wk, wv, wo *cpu.Tensor,
	qNorm, kNorm *cpu.Tensor,
	positions []int,
	cfg SelfAttentionConfig,
	cache kvcache.KVCacheInterface,
	layerIdx int,
) *cpu.Tensor {
	seqLen := x.Shape()[0]
	hiddenSize := wo.Shape()[0] // output projection maps back to the model's hidden size
	wqShape := wq.Shape()
	wkShape := wk.Shape()
	wvShape := wv.Shape()

	// Q, K, V projections
	xq := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})
	xk := ops.Zeros(tensor.Shape{seqLen, wkShape[0]})
	xv := ops.Zeros(tensor.Shape{seqLen, wvShape[0]})
	matmul.Linear(x, wq, xq)
	matmul.Linear(x, wk, xk)
	matmul.Linear(x, wv, xv)
	// Apply QK norm if available (Qwen3)
	if qNorm != nil {
		RMSNorm(xq, qNorm, cfg.RMSNormEps)
	}
	if kNorm != nil {
		RMSNorm(xk, kNorm, cfg.RMSNormEps)
	}

	// RoPE: rotate query and key vectors according to their positions
	RoPE(xq, positions, cfg.HeadDim, cfg.RopeTheta)
	RoPE(xk, positions, cfg.HeadDim, cfg.RopeTheta)

	// Head counts are derived from the projection shapes; with grouped-query
	// attention numKVHeads can be smaller than numQHeads.
	numQHeads := wqShape[0] / cfg.HeadDim
	numKVHeads := wkShape[0] / cfg.HeadDim

	attnOut := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})
	// KV cache logic: append the new K/V, then attend over the cached blocks.
	if cache != nil {
		views, startPos, err := cache.Append(layerIdx, xk, xv)
		if err != nil {
			panic(err)
		}
		// Prefer the packed-block path when the cache exposes it.
		if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
			pviews := pv.ViewsPacked(layerIdx)
			if len(pviews) > 0 {
				if err := CausalAttentionPackedBlocks(xq, pviews, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
					panic(err)
				}
				goto done
			}
		}
		if err := CausalAttentionBlocks(xq, views, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
			panic(err)
		}
	} else {
		// No cache: plain causal attention over the current sequence only.
		CausalAttention(xq, xk, xv, attnOut, numQHeads, numKVHeads, cfg.HeadDim)
	}

done:
	// Output projection
	out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
	matmul.Linear(attnOut, wo, out)
	return out
}
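// The sketch below is illustrative only: it shows the expected call shape and
// tensor layout for SelfAttention using zero-initialized weights, so the
// result is all zeros. The [outDim, inDim] weight layout is an assumption
// inferred from how the shapes are read above; real callers load trained
// weights and typically pass a KV cache plus non-nil qNorm/kNorm where the
// architecture requires them.
func exampleSelfAttentionUsage() *cpu.Tensor {
	cfg := SelfAttentionConfig{HeadDim: 64, NumHeads: 8, NumKVHeads: 4, RopeTheta: 10000, RMSNormEps: 1e-6}
	hiddenSize := cfg.HeadDim * cfg.NumHeads
	seqLen := 4

	x := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
	wq := ops.Zeros(tensor.Shape{cfg.NumHeads * cfg.HeadDim, hiddenSize})
	wk := ops.Zeros(tensor.Shape{cfg.NumKVHeads * cfg.HeadDim, hiddenSize})
	wv := ops.Zeros(tensor.Shape{cfg.NumKVHeads * cfg.HeadDim, hiddenSize})
	wo := ops.Zeros(tensor.Shape{hiddenSize, cfg.NumHeads * cfg.HeadDim})

	positions := []int{0, 1, 2, 3} // one absolute position per input row

	// nil qNorm/kNorm skips QK normalization; nil cache takes the uncached path.
	return SelfAttention(x, wq, wk, wv, wo, nil, nil, positions, cfg, nil, 0)
}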