// Package arch provides transformer architecture implementations.
package arch

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)

// DecoderConfig holds configuration for decoder-only transformer architectures.
type DecoderConfig struct {
	HiddenSize   int     // model (embedding) dimension
	NumHeads     int     // number of attention (query) heads
	NumKVHeads   int     // number of key/value heads (equals NumHeads without GQA)
	NumLayers    int     // number of transformer blocks
	Intermediate int     // MLP intermediate dimension
	VocabSize    int     // output vocabulary size
	HeadDim      int     // dimension of each attention head
	RopeTheta    float32 // RoPE base frequency
	RMSNormEps   float32 // epsilon used by RMSNorm
}
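
// Illustrative only: a filled-in config with values that roughly match a
// Llama-2-7B-sized model. These numbers are an example, not taken from this
// package or any bundled checkpoint.
var exampleLlamaCfg = DecoderConfig{
	HiddenSize:   4096,
	NumHeads:     32,
	NumKVHeads:   32,
	NumLayers:    32,
	Intermediate: 11008,
	VocabSize:    32000,
	HeadDim:      128,
	RopeTheta:    10000,
	RMSNormEps:   1e-5,
}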

// DecoderLayerWeights holds weight tensors for a decoder layer.
type DecoderLayerWeights struct {
	AttnNorm          *cpu.Tensor
	Wq, Wk, Wv, Wo    *cpu.Tensor
	QNorm, KNorm      *cpu.Tensor // Optional: QK normalization (Qwen3)
	MlpNorm           *cpu.Tensor
	WGate, WUp, WDown *cpu.Tensor
}
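
// Checkpoint mapping (an assumption based on common Hugging Face Llama/Qwen
// layouts, not a statement about this package's loader): these fields
// typically correspond to
//
//	AttnNorm          -> input_layernorm
//	Wq, Wk, Wv, Wo    -> self_attn.{q,k,v,o}_proj
//	QNorm, KNorm      -> self_attn.{q,k}_norm (Qwen3)
//	MlpNorm           -> post_attention_layernorm
//	WGate, WUp, WDown -> mlp.{gate,up,down}_proj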

// DecoderForward runs the forward pass for a decoder-only transformer
// (used by Qwen3, Llama, Mistral, etc.).
//
// hidden:  [seqLen, hiddenSize]
// Returns: logits [seqLen, vocabSize]
func DecoderForward(
	hidden *cpu.Tensor,
	layers []*DecoderLayerWeights,
	norm, output *cpu.Tensor,
	positions []int,
	cfg DecoderConfig,
	cache model.KVCache,
) *cpu.Tensor {
	seqLen := hidden.Shape()[0]

	// Unwrap the concrete KV cache implementation, if one was provided.
	var kv kvcache.KVCacheInterface
	if cache != nil {
		if as, ok := cache.(kvcache.KVCacheInterface); ok {
			kv = as
		}
	}

	attnCfg := nn.SelfAttentionConfig{
		HeadDim:    cfg.HeadDim,
		RopeTheta:  cfg.RopeTheta,
		RMSNormEps: cfg.RMSNormEps,
	}

	// Transformer layers: each block runs self-attention and a SwiGLU MLP.
	for layerIdx, layer := range layers {
		hidden = nn.TransformerBlock(hidden,
			layer.AttnNorm, layer.MlpNorm, cfg.RMSNormEps,
			func(x *cpu.Tensor) *cpu.Tensor {
				return nn.SelfAttention(x,
					layer.Wq, layer.Wk, layer.Wv, layer.Wo,
					layer.QNorm, layer.KNorm,
					positions, attnCfg, kv, layerIdx,
				)
			},
			func(x *cpu.Tensor) *cpu.Tensor {
				return nn.SwiGLUMLP(x, layer.WGate, layer.WUp, layer.WDown)
			},
		)
	}

	// Final norm (applied in place on hidden).
	nn.RMSNorm(hidden, norm, cfg.RMSNormEps)

	// LM head: project hidden states to vocabulary logits.
	logits := ops.Zeros(tensor.Shape{seqLen, cfg.VocabSize})
	matmul.Linear(hidden, output, logits)
	return logits
}
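
// The sketch below is illustrative only and not part of the package API: it
// shows how DecoderForward is intended to be wired for a short prefill step.
// A real caller would produce hidden from a token-embedding lookup and load
// layer weights from a checkpoint; ops.Zeros merely stands in for that here.
func exampleDecoderForward(
	layers []*DecoderLayerWeights,
	norm, output *cpu.Tensor,
	cfg DecoderConfig,
	cache model.KVCache,
) *cpu.Tensor {
	const seqLen = 4
	positions := make([]int, seqLen)
	for i := range positions {
		positions[i] = i // absolute positions 0..seqLen-1 for a fresh sequence
	}
	// Placeholder activations with the expected [seqLen, hiddenSize] shape.
	hidden := ops.Zeros(tensor.Shape{seqLen, cfg.HiddenSize})
	return DecoderForward(hidden, layers, norm, output, positions, cfg, cache)
}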