// Package qwen3 implements the Qwen3 model family with device-agnostic execution.
// Supports: Qwen3-0.6B, Qwen3-1.7B, Qwen3-4B, Qwen3-8B, Qwen3-14B, Qwen3-32B
// The model works with both CPU and GPU placement - the compute package handles dispatching.
package qwen3

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
// Model implements the Qwen3 architecture.
// Weight tensors are attached after construction via SetTensor; until then
// the tensor fields are nil.
type Model struct {
	config   *model.Config
	tokenEmb tensor.Tensor // token embedding table ("model.embed_tokens.weight")
	layers   []*Layer      // one entry per cfg.NumLayers, populated by New
	norm     tensor.Tensor // final norm weight ("model.norm.weight")
	output   tensor.Tensor // LM head projection ("lm_head.weight")
}
// Layer represents a single Qwen3 transformer layer.
// Fields are filled in by setTensor from "model.layers.<idx>.*" checkpoint names.
type Layer struct {
	idx      int           // layer index within Model.layers
	attnNorm tensor.Tensor // "input_layernorm.weight"
	// Attention projections: query, key, value, output
	// ("self_attn.{q,k,v,o}_proj.weight").
	wq, wk, wv, wo tensor.Tensor
	// Per-head Q/K norms, a Qwen3-specific addition
	// ("self_attn.{q,k}_norm.weight").
	qNorm, kNorm tensor.Tensor
	mlpNorm      tensor.Tensor // "post_attention_layernorm.weight"
	// SwiGLU MLP weights ("mlp.{gate,up,down}_proj.weight").
	wGate, wUp, wDown tensor.Tensor
}
  28. // New creates a new Qwen3 model
  29. func New(cfg *model.Config) (model.Model, error) {
  30. m := &Model{config: cfg, layers: make([]*Layer, cfg.NumLayers)}
  31. for i := range m.layers {
  32. m.layers[i] = &Layer{idx: i}
  33. }
  34. return m, nil
  35. }
  36. func (m *Model) Config() *model.Config { return m.config }
  37. func (m *Model) Close() error { return nil }
  38. func (m *Model) SetTensor(name string, t tensor.Tensor) error {
  39. switch name {
  40. case "model.embed_tokens.weight":
  41. m.tokenEmb = t
  42. case "model.norm.weight":
  43. m.norm = t
  44. case "lm_head.weight":
  45. m.output = t
  46. default:
  47. var idx int
  48. var suffix string
  49. if _, err := fmt.Sscanf(name, "model.layers.%d.%s", &idx, &suffix); err == nil && idx < len(m.layers) {
  50. m.layers[idx].setTensor(suffix, t)
  51. }
  52. }
  53. return nil
  54. }
  55. func (l *Layer) setTensor(name string, t tensor.Tensor) {
  56. switch name {
  57. case "input_layernorm.weight":
  58. l.attnNorm = t
  59. case "self_attn.q_proj.weight":
  60. l.wq = t
  61. case "self_attn.k_proj.weight":
  62. l.wk = t
  63. case "self_attn.v_proj.weight":
  64. l.wv = t
  65. case "self_attn.o_proj.weight":
  66. l.wo = t
  67. case "self_attn.q_norm.weight":
  68. l.qNorm = t
  69. case "self_attn.k_norm.weight":
  70. l.kNorm = t
  71. case "post_attention_layernorm.weight":
  72. l.mlpNorm = t
  73. case "mlp.gate_proj.weight":
  74. l.wGate = t
  75. case "mlp.up_proj.weight":
  76. l.wUp = t
  77. case "mlp.down_proj.weight":
  78. l.wDown = t
  79. }
  80. }
  81. // asCPU safely converts a tensor to *cpu.Tensor
  82. // This is a transitional helper - eventually all ops will be device-aware
  83. func asCPU(t tensor.Tensor) *cpu.Tensor {
  84. if ct, ok := t.(*cpu.Tensor); ok {
  85. return ct
  86. }
  87. panic(fmt.Sprintf("expected *cpu.Tensor, got %T", t))
  88. }
// Forward is implemented in forward_device.go to use device-aware operations.
// This allows the same code to work with both CPU and GPU without duplication.

// init registers this architecture under its HuggingFace-style name so the
// loader can construct a Qwen3 model from a checkpoint's declared type.
func init() {
	model.Register("Qwen3ForCausalLM", New)
}