decoder.go 2.1 KB

// Package arch provides transformer architecture implementations.
package arch

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
// DecoderConfig holds configuration for decoder-only transformer architectures.
type DecoderConfig struct {
	HiddenSize   int     // model (embedding) dimension
	NumHeads     int     // number of attention heads
	NumKVHeads   int     // number of key/value heads (GQA when < NumHeads)
	NumLayers    int     // number of transformer layers
	Intermediate int     // MLP intermediate size
	VocabSize    int     // vocabulary size for the LM head
	HeadDim      int     // per-head dimension
	RopeTheta    float32 // RoPE base frequency
	RMSNormEps   float32 // epsilon used in RMSNorm
}
// DecoderLayerWeights holds weight tensors for a decoder layer.
type DecoderLayerWeights struct {
	AttnNorm          *cpu.Tensor
	Wq, Wk, Wv, Wo    *cpu.Tensor
	QNorm, KNorm      *cpu.Tensor // Optional: QK normalization (Qwen3)
	MlpNorm           *cpu.Tensor
	WGate, WUp, WDown *cpu.Tensor
}
// DecoderForward runs the forward pass for a decoder-only transformer.
// Used by: Qwen3, Llama, Mistral, etc.
//
// hidden: [seqLen, hiddenSize]
// Returns: logits [seqLen, vocabSize]
func DecoderForward(
	hidden *cpu.Tensor,
	layers []*DecoderLayerWeights,
	norm, output *cpu.Tensor,
	positions []int,
	cfg DecoderConfig,
	cache model.KVCache,
) *cpu.Tensor {
	seqLen := hidden.Shape()[0]

	// Unwrap the concrete KV cache interface if one was provided;
	// kv remains nil otherwise.
	var kv kvcache.KVCacheInterface
	if cache != nil {
		if as, ok := cache.(kvcache.KVCacheInterface); ok {
			kv = as
		}
	}

	attnCfg := nn.SelfAttentionConfig{
		HeadDim:    cfg.HeadDim,
		RopeTheta:  cfg.RopeTheta,
		RMSNormEps: cfg.RMSNormEps,
	}
	// Transformer layers
	for layerIdx, layer := range layers {
		hidden = nn.TransformerBlock(hidden,
			layer.AttnNorm, layer.MlpNorm, cfg.RMSNormEps,
			func(x *cpu.Tensor) *cpu.Tensor {
				return nn.SelfAttention(x,
					layer.Wq, layer.Wk, layer.Wv, layer.Wo,
					layer.QNorm, layer.KNorm,
					positions, attnCfg, kv, layerIdx,
				)
			},
			func(x *cpu.Tensor) *cpu.Tensor {
				return nn.SwiGLUMLP(x, layer.WGate, layer.WUp, layer.WDown)
			},
		)
	}

	// Final norm
	nn.RMSNorm(hidden, norm, cfg.RMSNormEps)

	// LM head
	logits := ops.Zeros(tensor.Shape{seqLen, cfg.VocabSize})
	matmul.Linear(hidden, output, logits)
	return logits
}
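
Usage sketch (not part of decoder.go): the snippet below shows how a model implementation might call DecoderForward for a prompt prefill. It assumes the layer weights, final norm, and LM head tensors were loaded elsewhere, and every numeric value in the config is an illustrative placeholder rather than any specific model's real configuration; the function name prefill is likewise hypothetical.

// decoder_usage_example.go — hypothetical caller, for illustration only.
package arch

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/model"
)

// prefill runs the decoder over a prompt. It assumes `hidden` already holds
// the token embeddings ([seqLen, hiddenSize]) and that the per-layer weights,
// final norm, and LM head were loaded by the model reader. All config values
// below are illustrative placeholders.
func prefill(hidden *cpu.Tensor, layers []*DecoderLayerWeights,
	finalNorm, lmHead *cpu.Tensor, cache model.KVCache) *cpu.Tensor {

	cfg := DecoderConfig{
		HiddenSize:   2048,
		NumHeads:     16,
		NumKVHeads:   8, // grouped-query attention: fewer KV heads than query heads
		NumLayers:    len(layers),
		Intermediate: 8192,
		VocabSize:    32000,
		HeadDim:      128,
		RopeTheta:    10000,
		RMSNormEps:   1e-6,
	}

	// For a fresh prompt the positions are simply 0..seqLen-1; on later
	// decode steps they continue from the number of tokens already cached.
	seqLen := hidden.Shape()[0]
	positions := make([]int, seqLen)
	for i := range positions {
		positions[i] = i
	}

	// logits has shape [seqLen, VocabSize]; a sampler would normally pick
	// the next token from the last row.
	return DecoderForward(hidden, layers, finalNorm, lmHead, positions, cfg, cache)
}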