package kimi_linear

import (
	"fmt"
	"strconv"
	"strings"

	"makarna/pkg/model"
	"makarna/pkg/tensor"
)
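
// layerKind distinguishes the two attention variants a layer can use:
// KDA (linear attention, backed by a recurrent state cache) and full MLA attention.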
type layerKind uint8

const (
	layerKindUnknown layerKind = iota
	layerKindKDA
	layerKindFull
)
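
// Layer holds the weights of a single transformer layer. Depending on kind,
// either the kda* (linear attention) or the mla* (full attention) set is
// populated; the feed-forward block is either a dense gate/up/down MLP or a
// block-sparse MoE.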
type Layer struct {
	idx  int
	kind layerKind

	inputLayernorm tensor.Tensor
	postAttnNorm   tensor.Tensor

	// KDA (linear attention) weights.
	kdaQProj  tensor.Tensor
	kdaKProj  tensor.Tensor
	kdaVProj  tensor.Tensor
	kdaQConv  tensor.Tensor
	kdaKConv  tensor.Tensor
	kdaVConv  tensor.Tensor
	kdaFAProj tensor.Tensor
	kdaFBProj tensor.Tensor
	kdaDTBias tensor.Tensor
	kdaBProj  tensor.Tensor
	kdaALog   tensor.Tensor
	kdaGAProj tensor.Tensor
	kdaGBProj tensor.Tensor
	kdaONorm  tensor.Tensor
	kdaOProj  tensor.Tensor

	// MLA (full attention) weights.
	mlaQProj          tensor.Tensor
	mlaKVAProjWithMQA tensor.Tensor
	mlaKVALayernorm   tensor.Tensor
	mlaKVBProj        tensor.Tensor
	mlaOProj          tensor.Tensor

	// Dense MLP weights.
	mlpGateProj tensor.Tensor
	mlpUpProj   tensor.Tensor
	mlpDownProj tensor.Tensor

	// Block-sparse MoE weights.
	moeGateW      tensor.Tensor
	moeGateBias   tensor.Tensor
	moeW1         []tensor.Tensor
	moeW2         []tensor.Tensor
	moeW3         []tensor.Tensor
	moeSharedGate tensor.Tensor
	moeSharedUp   tensor.Tensor
	moeSharedDown tensor.Tensor
}
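
// Model holds all weights for KimiLinearForCausalLM together with the head
// geometry needed to size the KDA and MLA caches.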
type Model struct {
	config *model.Config

	tokenEmb tensor.Tensor
	norm     tensor.Tensor
	output   tensor.Tensor
	layers   []*Layer

	kdaNumHeads int
	kdaHeadDim  int
	mlaNumHeads int
	mlaKHeadDim int
	mlaVHeadDim int

	isKDALayer map[int]bool
	tensors    map[string]tensor.Tensor
}
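
// New constructs an empty Model from the config, allocating one Layer per
// transformer layer and marking each as KDA or full attention using the
// 1-based layer lists from the linear-attention config.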
func New(cfg *model.Config) (model.Model, error) {
	kdaCfg, _ := parseLinearAttnConfig(cfg)
	mlaCfg, _ := parseMLAConfig(cfg)
	m := &Model{
		config:      cfg,
		tensors:     make(map[string]tensor.Tensor),
		kdaNumHeads: kdaCfg.NumHeads,
		kdaHeadDim:  kdaCfg.HeadDim,
		mlaNumHeads: mlaCfg.NumHeads,
		mlaKHeadDim: mlaCfg.QKNopeHeadDim + mlaCfg.QKRopeHeadDim,
		mlaVHeadDim: mlaCfg.VHeadDim,
		isKDALayer:  make(map[int]bool),
	}
	if cfg.NumLayers > 0 {
		m.layers = make([]*Layer, cfg.NumLayers)
		for i := range m.layers {
			m.layers[i] = &Layer{idx: i, kind: layerKindUnknown}
		}
		for _, oneBased := range kdaCfg.KDALayers {
			idx := oneBased - 1
			if idx >= 0 && idx < cfg.NumLayers {
				m.isKDALayer[idx] = true
				m.layers[idx].kind = layerKindKDA
			}
		}
		for _, oneBased := range kdaCfg.FullAttnLayers {
			idx := oneBased - 1
			if idx >= 0 && idx < cfg.NumLayers {
				if !m.isKDALayer[idx] {
					m.layers[idx].kind = layerKindFull
				}
			}
		}
	}
	return m, nil
}
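
// Config returns the model configuration.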
func (m *Model) Config() *model.Config { return m.config }
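
// Close implements model.Model; there are no resources to release.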
func (m *Model) Close() error { return nil }
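
// Validate reports an error if any tensor required for inference is missing:
// the global weights, the per-layer norms, the KDA or MLA attention weights
// matching each layer's kind, and either dense MLP or MoE expert weights.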
func (m *Model) Validate() error {
	if m.config == nil {
		return fmt.Errorf("kimi_linear: missing config")
	}
	if m.tokenEmb == nil {
		return fmt.Errorf("kimi_linear: missing model.embed_tokens.weight")
	}
	if m.norm == nil {
		return fmt.Errorf("kimi_linear: missing model.norm.weight")
	}
	if m.output == nil {
		return fmt.Errorf("kimi_linear: missing lm_head.weight")
	}
	if len(m.layers) != m.config.NumLayers {
		return fmt.Errorf("kimi_linear: layer count mismatch: have %d want %d", len(m.layers), m.config.NumLayers)
	}
	for i, l := range m.layers {
		if l == nil {
			return fmt.Errorf("kimi_linear: layer %d is nil", i)
		}
		if l.inputLayernorm == nil {
			return fmt.Errorf("kimi_linear: layer %d missing input_layernorm", i)
		}
		if l.postAttnNorm == nil {
			return fmt.Errorf("kimi_linear: layer %d missing post_attention_layernorm", i)
		}
		if l.kind == layerKindUnknown {
			return fmt.Errorf("kimi_linear: layer %d has unknown kind (check config kda/full layers and weight names)", i)
		}
		switch l.kind {
		case layerKindKDA:
			if l.kdaQProj == nil || l.kdaKProj == nil || l.kdaVProj == nil || l.kdaOProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA projections", i)
			}
			if l.kdaQConv == nil || l.kdaKConv == nil || l.kdaVConv == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA conv weights", i)
			}
			if l.kdaFAProj == nil || l.kdaFBProj == nil || l.kdaBProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA gate/beta projections", i)
			}
			if l.kdaALog == nil || l.kdaDTBias == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA A_log/dt_bias", i)
			}
			if l.kdaGAProj == nil || l.kdaGBProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA o_norm gate projections", i)
			}
			if l.kdaONorm == nil {
				return fmt.Errorf("kimi_linear: layer %d missing KDA o_norm", i)
			}
		case layerKindFull:
			if l.mlaQProj == nil || l.mlaKVAProjWithMQA == nil || l.mlaKVALayernorm == nil || l.mlaKVBProj == nil || l.mlaOProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing MLA weights", i)
			}
		default:
			return fmt.Errorf("kimi_linear: layer %d has invalid kind=%d", i, l.kind)
		}
		// MLP must exist if MoE isn't present.
		useMoE := l.moeGateW != nil && len(l.moeW1) > 0 && len(l.moeW2) > 0 && len(l.moeW3) > 0
		if !useMoE {
			if l.mlpGateProj == nil || l.mlpUpProj == nil || l.mlpDownProj == nil {
				return fmt.Errorf("kimi_linear: layer %d missing MLP weights", i)
			}
		}
	}
	return nil
}

// CacheType implements model.CacheFactory - KimiLinear uses a recurrent state cache (KDA).
func (m *Model) CacheType() model.CacheType {
	return model.CacheTypeRecurrent
}

// CreateCache implements model.CacheFactory - creates a KimiCache for this model.
func (m *Model) CreateCache() (model.KVCache, error) {
	cfg := m.config
	kdaCfg, _ := parseLinearAttnConfig(cfg)
	mlaCfg, _ := parseMLAConfig(cfg)
	return NewKimiCache(
		cfg.NumLayers,
		kdaCfg.NumHeads,
		kdaCfg.HeadDim,
		kdaCfg.ShortConvKernel,
		mlaCfg.NumHeads,
		mlaCfg.QKNopeHeadDim+mlaCfg.QKRopeHeadDim,
		mlaCfg.VHeadDim,
	)
}
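
// SetTensor routes a named checkpoint tensor into the model. Global tensors
// are matched by exact name; names of the form "model.layers.<i>.<suffix>"
// are dispatched to layer <i> with any trailing ".weight" stripped, e.g.
// "model.layers.0.self_attn.q_proj.weight" becomes layer 0, suffix
// "self_attn.q_proj". Anything else is kept in the catch-all tensors map.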
func (m *Model) SetTensor(name string, t tensor.Tensor) error {
	switch name {
	case "model.embed_tokens.weight":
		m.tokenEmb = t
	case "model.norm.weight":
		m.norm = t
	case "lm_head.weight":
		m.output = t
	default:
		var idx int
		var suffix string
		if _, err := fmt.Sscanf(name, "model.layers.%d.%s", &idx, &suffix); err == nil && idx >= 0 && idx < len(m.layers) {
			m.layers[idx].setTensor(strings.TrimSuffix(suffix, ".weight"), name, t)
			return nil
		}
		if m.tensors == nil {
			m.tensors = make(map[string]tensor.Tensor)
		}
		m.tensors[name] = t
	}
	return nil
}
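
// setTensor stores one layer tensor by its suffix (trailing ".weight" already
// stripped). MoE names are handled first; the remaining suffixes are matched
// against the norm/MLP names and then against both the KDA and MLA attention
// names, since the two layer kinds share suffixes such as self_attn.q_proj
// and self_attn.o_proj.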
func (l *Layer) setTensor(suffix string, fullName string, t tensor.Tensor) {
	if strings.HasPrefix(suffix, "block_sparse_moe.") {
		if suffix == "block_sparse_moe.gate.weight" || suffix == "block_sparse_moe.gate" {
			l.moeGateW = t
			return
		}
		if suffix == "block_sparse_moe.gate.e_score_correction_bias" {
			l.moeGateBias = t
			return
		}
		if strings.HasPrefix(suffix, "block_sparse_moe.experts.") {
			rest := strings.TrimPrefix(suffix, "block_sparse_moe.experts.")
			parts := strings.Split(rest, ".")
			if len(parts) == 2 || len(parts) == 3 {
				idx, err := strconv.Atoi(parts[0])
				if err == nil && idx >= 0 {
					wName := parts[len(parts)-1]
					// Grow all three expert slices together so they keep the same length.
					need := idx + 1
					if len(l.moeW1) < need {
						grow := make([]tensor.Tensor, need)
						copy(grow, l.moeW1)
						l.moeW1 = grow
						grow = make([]tensor.Tensor, need)
						copy(grow, l.moeW2)
						l.moeW2 = grow
						grow = make([]tensor.Tensor, need)
						copy(grow, l.moeW3)
						l.moeW3 = grow
					}
					switch wName {
					case "w1":
						l.moeW1[idx] = t
						return
					case "w2":
						l.moeW2[idx] = t
						return
					case "w3":
						l.moeW3[idx] = t
						return
					}
				}
			}
		}
		if suffix == "block_sparse_moe.shared_experts.gate_proj" {
			l.moeSharedGate = t
			return
		}
		if suffix == "block_sparse_moe.shared_experts.up_proj" {
			l.moeSharedUp = t
			return
		}
		if suffix == "block_sparse_moe.shared_experts.down_proj" {
			l.moeSharedDown = t
			return
		}
	}
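	// Norms and dense MLP.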
	switch suffix {
	case "input_layernorm":
		l.inputLayernorm = t
	case "post_attention_layernorm":
		l.postAttnNorm = t
	case "mlp.gate_proj":
		l.mlpGateProj = t
	case "mlp.up_proj":
		l.mlpUpProj = t
	case "mlp.down_proj":
		l.mlpDownProj = t
	}
	// KDA
	switch suffix {
	case "self_attn.q_proj":
		l.kdaQProj = t
	case "self_attn.k_proj":
		l.kdaKProj = t
	case "self_attn.v_proj":
		l.kdaVProj = t
	case "self_attn.q_conv1d":
		l.kdaQConv = t
	case "self_attn.k_conv1d":
		l.kdaKConv = t
	case "self_attn.v_conv1d":
		l.kdaVConv = t
	case "self_attn.f_a_proj":
		l.kdaFAProj = t
	case "self_attn.f_b_proj":
		l.kdaFBProj = t
	case "self_attn.dt_bias":
		l.kdaDTBias = t
	case "self_attn.b_proj":
		l.kdaBProj = t
	case "self_attn.A_log":
		l.kdaALog = t
	case "self_attn.g_a_proj":
		l.kdaGAProj = t
	case "self_attn.g_b_proj":
		l.kdaGBProj = t
	case "self_attn.o_norm":
		l.kdaONorm = t
	case "self_attn.o_proj":
		l.kdaOProj = t
	}
	// MLA
	switch suffix {
	case "self_attn.q_proj":
		l.mlaQProj = t
	case "self_attn.kv_a_proj_with_mqa":
		l.mlaKVAProjWithMQA = t
	case "self_attn.kv_a_layernorm":
		l.mlaKVALayernorm = t
	case "self_attn.kv_b_proj":
		l.mlaKVBProj = t
	case "self_attn.o_proj":
		l.mlaOProj = t
	}
	_ = fullName
}
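
// init registers this architecture with the model registry so the loader can
// construct it by name.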
func init() {
	model.Register("KimiLinearForCausalLM", New)
}