package nn

import (
    "makarna/pkg/backend/cpu"
    "makarna/pkg/backend/cpu/matmul"
    "makarna/pkg/backend/cpu/ops"
    "makarna/pkg/kvcache"
    "makarna/pkg/tensor"
)

// SelfAttentionConfig holds parameters for self-attention computation.
type SelfAttentionConfig struct {
    HeadDim    int
    NumHeads   int
    NumKVHeads int
    RopeTheta  float32
    RMSNormEps float32
}

// SelfAttention computes full self-attention with optional QK norm and KV cache.
// x: input tensor [seqLen, hiddenSize]
// wq, wk, wv, wo: projection weights (output features along dim 0)
// qNorm, kNorm: optional QK normalization weights (can be nil)
// positions: position indices for RoPE
// cfg: attention configuration (head dim, RoPE theta, RMSNorm epsilon)
// cache: optional KV cache (can be nil)
// layerIdx: layer index for cache
func SelfAttention(
    x *cpu.Tensor,
    wq, wk, wv, wo *cpu.Tensor,
    qNorm, kNorm *cpu.Tensor,
    positions []int,
    cfg SelfAttentionConfig,
    cache kvcache.KVCacheInterface,
    layerIdx int,
) *cpu.Tensor {
    seqLen := x.Shape()[0]
    hiddenSize := wo.Shape()[0]
    wqShape := wq.Shape()
    wkShape := wk.Shape()
    wvShape := wv.Shape()

    // Q, K, V projections.
    xq := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})
    xk := ops.Zeros(tensor.Shape{seqLen, wkShape[0]})
    xv := ops.Zeros(tensor.Shape{seqLen, wvShape[0]})
    matmul.Linear(x, wq, xq)
    matmul.Linear(x, wk, xk)
    matmul.Linear(x, wv, xv)

    // Apply QK norm if available (Qwen3).
    if qNorm != nil {
        RMSNorm(xq, qNorm, cfg.RMSNormEps)
    }
    if kNorm != nil {
        RMSNorm(xk, kNorm, cfg.RMSNormEps)
    }

    // RoPE positional encoding on queries and keys.
    RoPE(xq, positions, cfg.HeadDim, cfg.RopeTheta)
    RoPE(xk, positions, cfg.HeadDim, cfg.RopeTheta)

    // Head counts are derived from the projection shapes, which supports
    // grouped-query attention (fewer KV heads than query heads).
    numQHeads := wqShape[0] / cfg.HeadDim
    numKVHeads := wkShape[0] / cfg.HeadDim

    attnOut := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})

    // KV cache path: append the new K/V, then attend over the cached blocks.
    // If the cache exposes packed views, take the packed fast path and jump
    // straight to the output projection.
    if cache != nil {
        views, startPos, err := cache.Append(layerIdx, xk, xv)
        if err != nil {
            panic(err)
        }
        if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
            pviews := pv.ViewsPacked(layerIdx)
            if len(pviews) > 0 {
                if err := CausalAttentionPackedBlocks(xq, pviews, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
                    panic(err)
                }
                goto done
            }
        }
        if err := CausalAttentionBlocks(xq, views, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
            panic(err)
        }
    } else {
        // No cache: plain causal attention over the current sequence.
        CausalAttention(xq, xk, xv, attnOut, numQHeads, numKVHeads, cfg.HeadDim)
    }

done:
    // Output projection.
    out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
    matmul.Linear(attnOut, wo, out)
    return out
}
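
// Usage sketch (illustrative only): wiring SelfAttention for a single layer
// with zero-initialized weights, no QK norm, and no KV cache. The shape
// conventions here are assumptions inferred from how the weights are used
// above (output features along dim 0); the sizes are arbitrary and the
// function name is hypothetical.
func exampleSelfAttentionUsage() *cpu.Tensor {
    const (
        seqLen     = 4
        hiddenSize = 64
        headDim    = 16
        numHeads   = 4
        numKVHeads = 2
    )

    // Input activations and per-layer projection weights (assumed [out, in]).
    x := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
    wq := ops.Zeros(tensor.Shape{numHeads * headDim, hiddenSize})
    wk := ops.Zeros(tensor.Shape{numKVHeads * headDim, hiddenSize})
    wv := ops.Zeros(tensor.Shape{numKVHeads * headDim, hiddenSize})
    wo := ops.Zeros(tensor.Shape{hiddenSize, numHeads * headDim})

    // One position index per token, consumed by RoPE.
    positions := make([]int, seqLen)
    for i := range positions {
        positions[i] = i
    }

    cfg := SelfAttentionConfig{
        HeadDim:    headDim,
        NumHeads:   numHeads,
        NumKVHeads: numKVHeads,
        RopeTheta:  10000,
        RMSNormEps: 1e-6,
    }

    // nil qNorm/kNorm skips QK normalization; a nil cache takes the uncached
    // CausalAttention path, so layerIdx is effectively unused here.
    return SelfAttention(x, wq, wk, wv, wo, nil, nil, positions, cfg, nil, 0)
}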