model.go

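// Package kimi_linear wires up the KimiLinearForCausalLM model: it holds the
// per-layer KDA (linear attention) and MLA (full attention) weights and maps
// checkpoint tensor names onto them.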
package kimi_linear

import (
	"fmt"
	"strconv"
	"strings"

	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
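
// layerKind distinguishes the two attention variants a decoder layer can use:
// KDA (linear attention with a recurrent state) or full attention (MLA).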
type layerKind uint8

const (
	layerKindUnknown layerKind = iota
	layerKindKDA
	layerKindFull
)
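
// Layer holds the weights of a single decoder layer. Depending on kind, either
// the kda* (linear attention) or mla* (full attention) tensors are populated;
// the MLP/MoE tensors are common to both kinds.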
type Layer struct {
	idx  int
	kind layerKind

	inputLayernorm tensor.Tensor
	postAttnNorm   tensor.Tensor

	// KDA (linear attention) weights.
	kdaQProj  tensor.Tensor
	kdaKProj  tensor.Tensor
	kdaVProj  tensor.Tensor
	kdaQConv  tensor.Tensor
	kdaKConv  tensor.Tensor
	kdaVConv  tensor.Tensor
	kdaFAProj tensor.Tensor
	kdaFBProj tensor.Tensor
	kdaDTBias tensor.Tensor
	kdaBProj  tensor.Tensor
	kdaALog   tensor.Tensor
	kdaGAProj tensor.Tensor
	kdaGBProj tensor.Tensor
	kdaONorm  tensor.Tensor
	kdaOProj  tensor.Tensor

	// MLA (full attention) weights.
	mlaQProj          tensor.Tensor
	mlaKVAProjWithMQA tensor.Tensor
	mlaKVALayernorm   tensor.Tensor
	mlaKVBProj        tensor.Tensor
	mlaOProj          tensor.Tensor

	// Dense MLP weights.
	mlpGateProj tensor.Tensor
	mlpUpProj   tensor.Tensor
	mlpDownProj tensor.Tensor

	// MoE weights (routed experts plus optional shared expert).
	moeGateW      tensor.Tensor
	moeGateBias   tensor.Tensor
	moeW1         []tensor.Tensor
	moeW2         []tensor.Tensor
	moeW3         []tensor.Tensor
	moeSharedGate tensor.Tensor
	moeSharedUp   tensor.Tensor
	moeSharedDown tensor.Tensor
}
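
// Model is the KimiLinear weight container. Tensors that do not map to a known
// field are kept in the tensors map.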
type Model struct {
	config   *model.Config
	tokenEmb tensor.Tensor
	norm     tensor.Tensor
	output   tensor.Tensor
	layers   []*Layer

	kdaNumHeads int
	kdaHeadDim  int
	mlaNumHeads int
	mlaKHeadDim int
	mlaVHeadDim int

	isKDALayer map[int]bool
	tensors    map[string]tensor.Tensor
}
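
// New builds an empty Model from the config, allocating one Layer per decoder
// layer and marking each as KDA or full attention from the 1-based layer lists
// in the linear-attention config.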
func New(cfg *model.Config) (model.Model, error) {
	kdaCfg, _ := parseLinearAttnConfig(cfg)
	mlaCfg, _ := parseMLAConfig(cfg)
	m := &Model{
		config:      cfg,
		tensors:     make(map[string]tensor.Tensor),
		kdaNumHeads: kdaCfg.NumHeads,
		kdaHeadDim:  kdaCfg.HeadDim,
		mlaNumHeads: mlaCfg.NumHeads,
		mlaKHeadDim: mlaCfg.QKNopeHeadDim + mlaCfg.QKRopeHeadDim,
		mlaVHeadDim: mlaCfg.VHeadDim,
		isKDALayer:  make(map[int]bool),
	}
	if cfg.NumLayers > 0 {
		m.layers = make([]*Layer, cfg.NumLayers)
		for i := range m.layers {
			m.layers[i] = &Layer{idx: i, kind: layerKindUnknown}
		}
		for _, oneBased := range kdaCfg.KDALayers {
			idx := oneBased - 1
			if idx >= 0 && idx < cfg.NumLayers {
				m.isKDALayer[idx] = true
				m.layers[idx].kind = layerKindKDA
			}
		}
		for _, oneBased := range kdaCfg.FullAttnLayers {
			idx := oneBased - 1
			if idx >= 0 && idx < cfg.NumLayers {
				if !m.isKDALayer[idx] {
					m.layers[idx].kind = layerKindFull
				}
			}
		}
	}
	return m, nil
}

func (m *Model) Config() *model.Config { return m.config }

func (m *Model) Close() error { return nil }
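
// Validate reports an error if any tensor required by a layer's kind (or by
// the dense MLP fallback when no MoE weights were loaded) is missing.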
func (m *Model) Validate() error {
	if m.config == nil {
		return fmt.Errorf("kimi_linear: missing config")
	}
	if m.tokenEmb == nil {
		return fmt.Errorf("kimi_linear: missing model.embed_tokens.weight")
	}
	if m.norm == nil {
		return fmt.Errorf("kimi_linear: missing model.norm.weight")
	}
	if m.output == nil {
		return fmt.Errorf("kimi_linear: missing lm_head.weight")
	}
	if len(m.layers) != m.config.NumLayers {
		return fmt.Errorf("kimi_linear: layer count mismatch: have %d want %d", len(m.layers), m.config.NumLayers)
	}
	for i, l := range m.layers {
		if l == nil {
			return fmt.Errorf("kimi_linear: layer %d is nil", i)
		}
		if l.inputLayernorm == nil {
			return fmt.Errorf("kimi_linear: layer %d missing input_layernorm", i)
		}
		if l.postAttnNorm == nil {
			return fmt.Errorf("kimi_linear: layer %d missing post_attention_layernorm", i)
		}
		if l.kind == layerKindUnknown {
			return fmt.Errorf("kimi_linear: layer %d has unknown kind (check config kda/full layers and weight names)", i)
		}
		switch l.kind {
		case layerKindKDA:
			if l.kdaQProj == nil || l.kdaKProj == nil || l.kdaVProj == nil || l.kdaOProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA projections", i)
			}
			if l.kdaQConv == nil || l.kdaKConv == nil || l.kdaVConv == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA conv weights", i)
			}
			if l.kdaFAProj == nil || l.kdaFBProj == nil || l.kdaBProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA gate/beta projections", i)
			}
			if l.kdaALog == nil || l.kdaDTBias == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA A_log/dt_bias", i)
			}
			if l.kdaGAProj == nil || l.kdaGBProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA o_norm gate projections", i)
			}
			if l.kdaONorm == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA o_norm", i)
			}
		case layerKindFull:
			if l.mlaQProj == nil || l.mlaKVAProjWithMQA == nil || l.mlaKVALayernorm == nil || l.mlaKVBProj == nil || l.mlaOProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing MLA weights", i)
			}
		default:
			return fmt.Errorf("kimi_linear: layer %d has invalid kind=%d", i, l.kind)
		}
		// MLP must exist if MoE isn't present.
		useMoE := l.moeGateW != nil && len(l.moeW1) > 0 && len(l.moeW2) > 0 && len(l.moeW3) > 0
		if !useMoE {
			if l.mlpGateProj == nil || l.mlpUpProj == nil || l.mlpDownProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing MLP weights", i)
			}
		}
	}
	return nil
}

// CacheType implements model.CacheFactory: KimiLinear uses a recurrent state cache (KDA).
func (m *Model) CacheType() model.CacheType {
	return model.CacheTypeRecurrent
}

// CreateCache implements model.CacheFactory - creates a KimiCache for this model.
func (m *Model) CreateCache() (model.KVCache, error) {
	cfg := m.config
	kdaCfg, _ := parseLinearAttnConfig(cfg)
	mlaCfg, _ := parseMLAConfig(cfg)
	return NewKimiCache(
		cfg.NumLayers,
		kdaCfg.NumHeads,
		kdaCfg.HeadDim,
		kdaCfg.ShortConvKernel,
		mlaCfg.NumHeads,
		mlaCfg.QKNopeHeadDim+mlaCfg.QKRopeHeadDim,
		mlaCfg.VHeadDim,
	)
}
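
// SetTensor routes a named checkpoint tensor to the matching top-level field,
// to its layer (for "model.layers.<i>." names), or to the fallback tensors map.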
func (m *Model) SetTensor(name string, t tensor.Tensor) error {
	switch name {
	case "model.embed_tokens.weight":
		m.tokenEmb = t
	case "model.norm.weight":
		m.norm = t
	case "lm_head.weight":
		m.output = t
	default:
		var idx int
		var suffix string
		if _, err := fmt.Sscanf(name, "model.layers.%d.%s", &idx, &suffix); err == nil && idx >= 0 && idx < len(m.layers) {
			m.layers[idx].setTensor(strings.TrimSuffix(suffix, ".weight"), name, t)
			return nil
		}
		if m.tensors == nil {
			m.tensors = make(map[string]tensor.Tensor)
		}
		m.tensors[name] = t
	}
	return nil
}
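
// setTensor assigns a layer-scoped tensor by its suffix: the tensor name with
// the "model.layers.<i>." prefix and any trailing ".weight" removed.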
func (l *Layer) setTensor(suffix string, fullName string, t tensor.Tensor) {
	if strings.HasPrefix(suffix, "block_sparse_moe.") {
		if suffix == "block_sparse_moe.gate.weight" || suffix == "block_sparse_moe.gate" {
			l.moeGateW = t
			return
		}
		if suffix == "block_sparse_moe.gate.e_score_correction_bias" {
			l.moeGateBias = t
			return
		}
		if strings.HasPrefix(suffix, "block_sparse_moe.experts.") {
			rest := strings.TrimPrefix(suffix, "block_sparse_moe.experts.")
			parts := strings.Split(rest, ".")
			if len(parts) == 2 || len(parts) == 3 {
				idx, err := strconv.Atoi(parts[0])
				if err == nil && idx >= 0 {
					wName := parts[len(parts)-1]
					// Grow the three expert weight slices in lockstep so that
					// expert idx stays aligned across w1/w2/w3.
					need := idx + 1
					if len(l.moeW1) < need {
						grow := make([]tensor.Tensor, need)
						copy(grow, l.moeW1)
						l.moeW1 = grow
						grow = make([]tensor.Tensor, need)
						copy(grow, l.moeW2)
						l.moeW2 = grow
						grow = make([]tensor.Tensor, need)
						copy(grow, l.moeW3)
						l.moeW3 = grow
					}
					switch wName {
					case "w1":
						l.moeW1[idx] = t
						return
					case "w2":
						l.moeW2[idx] = t
						return
					case "w3":
						l.moeW3[idx] = t
						return
					}
				}
			}
		}
		if suffix == "block_sparse_moe.shared_experts.gate_proj" {
			l.moeSharedGate = t
			return
		}
		if suffix == "block_sparse_moe.shared_experts.up_proj" {
			l.moeSharedUp = t
			return
		}
		if suffix == "block_sparse_moe.shared_experts.down_proj" {
			l.moeSharedDown = t
			return
		}
	}

	switch suffix {
	case "input_layernorm":
		l.inputLayernorm = t
	case "post_attention_layernorm":
		l.postAttnNorm = t
	case "mlp.gate_proj":
		l.mlpGateProj = t
	case "mlp.up_proj":
		l.mlpUpProj = t
	case "mlp.down_proj":
		l.mlpDownProj = t
	}
	// KDA (linear attention) weights. q_proj and o_proj are also stored in the
	// MLA switch below; Validate only checks the set matching the layer's kind.
	switch suffix {
	case "self_attn.q_proj":
		l.kdaQProj = t
	case "self_attn.k_proj":
		l.kdaKProj = t
	case "self_attn.v_proj":
		l.kdaVProj = t
	case "self_attn.q_conv1d":
		l.kdaQConv = t
	case "self_attn.k_conv1d":
		l.kdaKConv = t
	case "self_attn.v_conv1d":
		l.kdaVConv = t
	case "self_attn.f_a_proj":
		l.kdaFAProj = t
	case "self_attn.f_b_proj":
		l.kdaFBProj = t
	case "self_attn.dt_bias":
		l.kdaDTBias = t
	case "self_attn.b_proj":
		l.kdaBProj = t
	case "self_attn.A_log":
		l.kdaALog = t
	case "self_attn.g_a_proj":
		l.kdaGAProj = t
	case "self_attn.g_b_proj":
		l.kdaGBProj = t
	case "self_attn.o_norm":
		l.kdaONorm = t
	case "self_attn.o_proj":
		l.kdaOProj = t
	}
	// MLA (full attention) weights.
	switch suffix {
	case "self_attn.q_proj":
		l.mlaQProj = t
	case "self_attn.kv_a_proj_with_mqa":
		l.mlaKVAProjWithMQA = t
	case "self_attn.kv_a_layernorm":
		l.mlaKVALayernorm = t
	case "self_attn.kv_b_proj":
		l.mlaKVBProj = t
	case "self_attn.o_proj":
		l.mlaOProj = t
	}
	_ = fullName
}
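
// init registers the KimiLinearForCausalLM architecture name so the loader can
// construct this model via New.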
func init() {
	model.Register("KimiLinearForCausalLM", New)
}