package kimi_linear

import (
	"fmt"
	"strconv"
	"strings"

	"makarna/pkg/model"
	"makarna/pkg/tensor"
)

type layerKind uint8

const (
	layerKindUnknown layerKind = iota
	layerKindKDA
	layerKindFull
)

// Layer holds the weight tensors for a single transformer block. Depending on
// the layer kind, either the KDA (linear attention) or the MLA (full attention)
// tensors are populated and used.
type Layer struct {
	idx  int
	kind layerKind

	inputLayernorm tensor.Tensor
	postAttnNorm   tensor.Tensor

	// KDA (linear attention) weights.
	kdaQProj  tensor.Tensor
	kdaKProj  tensor.Tensor
	kdaVProj  tensor.Tensor
	kdaQConv  tensor.Tensor
	kdaKConv  tensor.Tensor
	kdaVConv  tensor.Tensor
	kdaFAProj tensor.Tensor
	kdaFBProj tensor.Tensor
	kdaDTBias tensor.Tensor
	kdaBProj  tensor.Tensor
	kdaALog   tensor.Tensor
	kdaGAProj tensor.Tensor
	kdaGBProj tensor.Tensor
	kdaONorm  tensor.Tensor
	kdaOProj  tensor.Tensor

	// MLA (full attention) weights.
	mlaQProj          tensor.Tensor
	mlaKVAProjWithMQA tensor.Tensor
	mlaKVALayernorm   tensor.Tensor
	mlaKVBProj        tensor.Tensor
	mlaOProj          tensor.Tensor

	// Dense MLP weights (used when the layer has no MoE experts).
	mlpGateProj tensor.Tensor
	mlpUpProj   tensor.Tensor
	mlpDownProj tensor.Tensor

	// MoE weights.
	moeGateW      tensor.Tensor
	moeGateBias   tensor.Tensor
	moeW1         []tensor.Tensor
	moeW2         []tensor.Tensor
	moeW3         []tensor.Tensor
	moeSharedGate tensor.Tensor
	moeSharedUp   tensor.Tensor
	moeSharedDown tensor.Tensor
}

// Model is the KimiLinear model: a hybrid stack of KDA (recurrent linear
// attention) layers and full-attention MLA layers.
type Model struct {
	config   *model.Config
	tokenEmb tensor.Tensor
	norm     tensor.Tensor
	output   tensor.Tensor
	layers   []*Layer

	kdaNumHeads int
	kdaHeadDim  int
	mlaNumHeads int
	mlaKHeadDim int
	mlaVHeadDim int

	isKDALayer map[int]bool
	tensors    map[string]tensor.Tensor
}

// New builds a KimiLinear model skeleton from the config; weight tensors are
// attached later via SetTensor and checked by Validate.
func New(cfg *model.Config) (model.Model, error) {
	kdaCfg, _ := parseLinearAttnConfig(cfg)
	mlaCfg, _ := parseMLAConfig(cfg)

	m := &Model{
		config:      cfg,
		tensors:     make(map[string]tensor.Tensor),
		kdaNumHeads: kdaCfg.NumHeads,
		kdaHeadDim:  kdaCfg.HeadDim,
		mlaNumHeads: mlaCfg.NumHeads,
		mlaKHeadDim: mlaCfg.QKNopeHeadDim + mlaCfg.QKRopeHeadDim,
		mlaVHeadDim: mlaCfg.VHeadDim,
		isKDALayer:  make(map[int]bool),
	}

	if cfg.NumLayers > 0 {
		m.layers = make([]*Layer, cfg.NumLayers)
		for i := range m.layers {
			m.layers[i] = &Layer{idx: i, kind: layerKindUnknown}
		}
		// Layer lists in the config are 1-based; convert to 0-based indices.
		for _, oneBased := range kdaCfg.KDALayers {
			idx := oneBased - 1
			if idx >= 0 && idx < cfg.NumLayers {
				m.isKDALayer[idx] = true
				m.layers[idx].kind = layerKindKDA
			}
		}
		for _, oneBased := range kdaCfg.FullAttnLayers {
			idx := oneBased - 1
			if idx >= 0 && idx < cfg.NumLayers && !m.isKDALayer[idx] {
				m.layers[idx].kind = layerKindFull
			}
		}
	}
	return m, nil
}

func (m *Model) Config() *model.Config { return m.config }

func (m *Model) Close() error { return nil }

// Validate checks that every tensor required by each layer's kind is present.
func (m *Model) Validate() error {
	if m.config == nil {
		return fmt.Errorf("kimi_linear: missing config")
	}
	if m.tokenEmb == nil {
		return fmt.Errorf("kimi_linear: missing model.embed_tokens.weight")
	}
	if m.norm == nil {
		return fmt.Errorf("kimi_linear: missing model.norm.weight")
	}
	if m.output == nil {
		return fmt.Errorf("kimi_linear: missing lm_head.weight")
	}
	if len(m.layers) != m.config.NumLayers {
		return fmt.Errorf("kimi_linear: layer count mismatch: have %d want %d", len(m.layers), m.config.NumLayers)
	}
	for i, l := range m.layers {
		if l == nil {
			return fmt.Errorf("kimi_linear: layer %d is nil", i)
		}
		if l.inputLayernorm == nil {
			return fmt.Errorf("kimi_linear: layer %d missing input_layernorm", i)
		}
		if l.postAttnNorm == nil {
			return fmt.Errorf("kimi_linear: layer %d missing post_attention_layernorm", i)
		}
		if l.kind == layerKindUnknown {
			return fmt.Errorf("kimi_linear: layer %d has unknown kind (check config kda/full layers and weight names)", i)
		}
		switch l.kind {
		case layerKindKDA:
			if l.kdaQProj == nil || l.kdaKProj == nil || l.kdaVProj == nil || l.kdaOProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA projections", i)
			}
			if l.kdaQConv == nil || l.kdaKConv == nil || l.kdaVConv == nil {
fmt.Errorf("kimi_linear: layer %d missing KDA conv weights", i) } if l.kdaFAProj == nil || l.kdaFBProj == nil || l.kdaBProj == nil { return fmt.Errorf("kimi_linear: layer %d missing KDA gate/beta projections", i) } if l.kdaALog == nil || l.kdaDTBias == nil { return fmt.Errorf("kimi_linear: layer %d missing KDA A_log/dt_bias", i) } if l.kdaGAProj == nil || l.kdaGBProj == nil { return fmt.Errorf("kimi_linear: layer %d missing KDA o_norm gate projections", i) } if l.kdaONorm == nil { return fmt.Errorf("kimi_linear: layer %d missing KDA o_norm", i) } case layerKindFull: if l.mlaQProj == nil || l.mlaKVAProjWithMQA == nil || l.mlaKVALayernorm == nil || l.mlaKVBProj == nil || l.mlaOProj == nil { return fmt.Errorf("kimi_linear: layer %d missing MLA weights", i) } default: return fmt.Errorf("kimi_linear: layer %d has invalid kind=%d", i, l.kind) } // MLP must exist if MoE isn't present. useMoE := l.moeGateW != nil && len(l.moeW1) > 0 && len(l.moeW2) > 0 && len(l.moeW3) > 0 if !useMoE { if l.mlpGateProj == nil || l.mlpUpProj == nil || l.mlpDownProj == nil { return fmt.Errorf("kimi_linear: layer %d missing MLP weights", i) } } } return nil } // CacheType implements model.CacheFactory - KimiLinear uses recurrent state cache (KDA). func (m *Model) CacheType() model.CacheType { return model.CacheTypeRecurrent } // CreateCache implements model.CacheFactory - creates a KimiCache for this model. func (m *Model) CreateCache() (model.KVCache, error) { cfg := m.config kdaCfg, _ := parseLinearAttnConfig(cfg) mlaCfg, _ := parseMLAConfig(cfg) return NewKimiCache( cfg.NumLayers, kdaCfg.NumHeads, kdaCfg.HeadDim, kdaCfg.ShortConvKernel, mlaCfg.NumHeads, mlaCfg.QKNopeHeadDim+mlaCfg.QKRopeHeadDim, mlaCfg.VHeadDim, ) } func (m *Model) SetTensor(name string, t tensor.Tensor) error { switch name { case "model.embed_tokens.weight": m.tokenEmb = t case "model.norm.weight": m.norm = t case "lm_head.weight": m.output = t default: var idx int var suffix string if _, err := fmt.Sscanf(name, "model.layers.%d.%s", &idx, &suffix); err == nil && idx >= 0 && idx < len(m.layers) { m.layers[idx].setTensor(strings.TrimSuffix(suffix, ".weight"), name, t) return nil } if m.tensors == nil { m.tensors = make(map[string]tensor.Tensor) } m.tensors[name] = t } return nil } func (l *Layer) setTensor(suffix string, fullName string, t tensor.Tensor) { if strings.HasPrefix(suffix, "block_sparse_moe.") { if suffix == "block_sparse_moe.gate.weight" || suffix == "block_sparse_moe.gate" { l.moeGateW = t return } if suffix == "block_sparse_moe.gate.e_score_correction_bias" { l.moeGateBias = t return } if strings.HasPrefix(suffix, "block_sparse_moe.experts.") { rest := strings.TrimPrefix(suffix, "block_sparse_moe.experts.") parts := strings.Split(rest, ".") if len(parts) == 2 || len(parts) == 3 { idx, err := strconv.Atoi(parts[0]) if err == nil && idx >= 0 { wName := parts[len(parts)-1] need := idx + 1 if len(l.moeW1) < need { grow := make([]tensor.Tensor, need) copy(grow, l.moeW1) l.moeW1 = grow grow = make([]tensor.Tensor, need) copy(grow, l.moeW2) l.moeW2 = grow grow = make([]tensor.Tensor, need) copy(grow, l.moeW3) l.moeW3 = grow } switch wName { case "w1": l.moeW1[idx] = t return case "w2": l.moeW2[idx] = t return case "w3": l.moeW3[idx] = t return } } } } if suffix == "block_sparse_moe.shared_experts.gate_proj" { l.moeSharedGate = t return } if suffix == "block_sparse_moe.shared_experts.up_proj" { l.moeSharedUp = t return } if suffix == "block_sparse_moe.shared_experts.down_proj" { l.moeSharedDown = t return } } switch suffix { case 
"input_layernorm": l.inputLayernorm = t case "post_attention_layernorm": l.postAttnNorm = t case "mlp.gate_proj": l.mlpGateProj = t case "mlp.up_proj": l.mlpUpProj = t case "mlp.down_proj": l.mlpDownProj = t } // KDA switch suffix { case "self_attn.q_proj": l.kdaQProj = t case "self_attn.k_proj": l.kdaKProj = t case "self_attn.v_proj": l.kdaVProj = t case "self_attn.q_conv1d": l.kdaQConv = t case "self_attn.k_conv1d": l.kdaKConv = t case "self_attn.v_conv1d": l.kdaVConv = t case "self_attn.f_a_proj": l.kdaFAProj = t case "self_attn.f_b_proj": l.kdaFBProj = t case "self_attn.dt_bias": l.kdaDTBias = t case "self_attn.b_proj": l.kdaBProj = t case "self_attn.A_log": l.kdaALog = t case "self_attn.g_a_proj": l.kdaGAProj = t case "self_attn.g_b_proj": l.kdaGBProj = t case "self_attn.o_norm": l.kdaONorm = t case "self_attn.o_proj": l.kdaOProj = t } // MLA switch suffix { case "self_attn.q_proj": l.mlaQProj = t case "self_attn.kv_a_proj_with_mqa": l.mlaKVAProjWithMQA = t case "self_attn.kv_a_layernorm": l.mlaKVALayernorm = t case "self_attn.kv_b_proj": l.mlaKVBProj = t case "self_attn.o_proj": l.mlaOProj = t } _ = fullName } func init() { model.Register("KimiLinearForCausalLM", New) }