cturan 2 months ago
Parent
Commit 54ed0123a6

+ 4 - 0
.gitignore

@@ -152,3 +152,7 @@ poetry.toml
 # IDE
 *.code-workspace
 .windsurf/
+
+
+# Devfiles
+devfiles/

+ 105 - 0
common/chat.cpp

@@ -579,6 +579,21 @@ common_chat_templates_ptr common_chat_templates_init(
             "{%- if false %}");
     }
 
+    // Fix MiniMax-M2 template bug: reasoning_content should be rendered for ALL assistant messages, not just after last user
+    // Original template has: {%- if reasoning_content and loop.index0 > ns.last_user_index -%}
+    // This causes reasoning from history to be lost, breaking interleaved thinking performance
+    // TODO: remove this once the upstream template is fixed; I don't have a server to upload GGUFs to yet.
+    if (default_template_src.find("]~!b[") != std::string::npos 
+            && default_template_src.find("]~b]") != std::string::npos
+            && default_template_src.find("loop.index0 > ns.last_user_index") != std::string::npos) {
+        LOG_INF("Detected MiniMax-M2 template with reasoning_content bug, applying automatic fix...\n");
+        // Remove the condition that prevents rendering reasoning_content for historical messages
+        string_replace_all(default_template_src,
+            "{%- if reasoning_content and loop.index0 > ns.last_user_index -%}",
+            "{%- if reasoning_content -%}");
+        LOG_INF("MiniMax-M2 template fixed: reasoning_content will now be preserved in conversation history\n");
+    }
+
     std::string token_bos = bos_token_override;
     std::string token_eos = eos_token_override;
     bool add_bos = false;
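
A minimal sketch (Python + jinja2, using a toy template rather than the real MiniMax-M2 one) of why the original condition drops reasoning for earlier assistant turns, and what the string replacement above restores:

# Toy template reproducing the buggy condition and the same replacement
# that string_replace_all applies in the hunk above.
import jinja2

toy_template = """
{%- set ns = namespace(last_user_index=-1) -%}
{%- for m in messages -%}
    {%- if m.role == 'user' -%}{%- set ns.last_user_index = loop.index0 -%}{%- endif -%}
{%- endfor -%}
{%- for m in messages -%}
    {%- if m.role == 'assistant' -%}
        {%- set reasoning_content = m.reasoning_content -%}
        {%- if reasoning_content and loop.index0 > ns.last_user_index -%}<think>{{ reasoning_content }}</think>{%- endif -%}
        {{- m.content -}}
    {%- endif -%}
{%- endfor -%}
"""

fixed_template = toy_template.replace(
    "{%- if reasoning_content and loop.index0 > ns.last_user_index -%}",
    "{%- if reasoning_content -%}")

messages = [
    {"role": "user",      "content": "q1"},
    {"role": "assistant", "content": "a1", "reasoning_content": "r1"},
    {"role": "user",      "content": "q2"},
]

env = jinja2.Environment()
print(env.from_string(toy_template).render(messages=messages))    # 'a1' -> r1 from history is lost
print(env.from_string(fixed_template).render(messages=messages))  # '<think>r1</think>a1'
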
@@ -640,6 +655,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
         case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
         case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
+        case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -1603,6 +1619,33 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha
     return data;
 }
 
+static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
+    common_chat_params data;
+    data.prompt = apply(tmpl, params);
+    data.format = COMMON_CHAT_FORMAT_MINIMAX_M2;
+    
+    // Handle thinking tags based on prompt ending
+    if (string_ends_with(data.prompt, "<think>\n")) {
+        if (!params.enable_thinking) {
+            // Close the thinking tag immediately if thinking is disabled
+            data.prompt += "</think>\n\n";
+        } else {
+            // Mark thinking as forced open (template started with <think>)
+            data.thinking_forced_open = true;
+        }
+    }
+    
+    // Preserve MiniMax-M2 special tokens
+    data.preserved_tokens = {
+        "<think>",
+        "</think>",
+        "<minimax:tool_call>",
+        "</minimax:tool_call>",
+    };
+    
+    return data;
+}
+
 static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("<think>", "</think>");
     if (!builder.syntax().parse_tool_calls) {
@@ -1624,6 +1667,60 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
         tool_calls_end);
 }
 
+static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
+    // MiniMax-M2 uses <think>...</think> tags for reasoning content
+    builder.try_parse_reasoning("<think>", "</think>");
+    
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // MiniMax-M2 uses <minimax:tool_call>...</minimax:tool_call> for tool calls
+    // Format: <invoke name="tool-name"><parameter name="key">value</parameter>...</invoke>
+    static const common_regex tool_call_begin_regex(regex_escape("<minimax:tool_call>"));
+    static const common_regex tool_call_end_regex(regex_escape("</minimax:tool_call>"));
+    static const common_regex invoke_begin_regex(regex_escape("<invoke name=\"") + "([^\"]+)" + regex_escape("\">"));
+    static const common_regex invoke_end_regex(regex_escape("</invoke>"));
+    static const common_regex param_regex(regex_escape("<parameter name=\"") + "([^\"]+)" + regex_escape("\">") + "([\\s\\S]*?)" + regex_escape("</parameter>"));
+    
+    if (builder.try_consume_regex(tool_call_begin_regex)) {
+        const auto & input = builder.input();
+        // Parse multiple <invoke> blocks within tool_call
+        while (auto invoke_match = builder.try_consume_regex(invoke_begin_regex)) {
+            auto & tool_name_range = invoke_match->groups[1];
+            std::string tool_name = input.substr(tool_name_range.begin, tool_name_range.end - tool_name_range.begin);
+            json arguments = json::object();
+            
+            // Parse parameters until </invoke>
+            while (!builder.try_consume_regex(invoke_end_regex)) {
+                if (auto param_match = builder.try_consume_regex(param_regex)) {
+                    auto & param_name_range = param_match->groups[1];
+                    auto & param_value_range = param_match->groups[2];
+                    std::string param_name = input.substr(param_name_range.begin, param_name_range.end - param_name_range.begin);
+                    std::string param_value = input.substr(param_value_range.begin, param_value_range.end - param_value_range.begin);
+                    
+                    // Try to parse as JSON, fallback to string
+                    try {
+                        arguments[param_name] = json::parse(param_value);
+                    } catch (...) {
+                        arguments[param_name] = param_value;
+                    }
+                } else {
+                    // If no more params, expect </invoke>
+                    break;
+                }
+            }
+            
+            builder.add_tool_call(tool_name, "", arguments.dump());
+        }
+        builder.consume_regex(tool_call_end_regex);
+    } else {
+        // No tool calls, just regular content
+        builder.add_content(builder.consume_rest());
+    }
+}
+
 static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
     static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)");
 
@@ -2748,6 +2845,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_apertus(tmpl, params);
     }
 
+    // MiniMax-M2 format detection
+    if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
+        return common_chat_params_init_minimax_m2(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2926,6 +3028,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_APERTUS:
             common_chat_parse_apertus(builder);
             break;
+        case COMMON_CHAT_FORMAT_MINIMAX_M2:
+            common_chat_parse_minimax_m2(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
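
For reference, a standalone Python sketch of the wire format handled by common_chat_parse_minimax_m2; the regexes and the JSON-or-string fallback mirror the C++ code above, while the helper name and sample text are illustrative only:

import json, re

INVOKE_RE = re.compile(r'<invoke name="([^"]+)">(.*?)</invoke>', re.S)
PARAM_RE  = re.compile(r'<parameter name="([^"]+)">([\s\S]*?)</parameter>')

def parse_minimax_m2_tool_calls(text: str) -> list[dict]:
    calls = []
    for block in re.findall(r'<minimax:tool_call>(.*?)</minimax:tool_call>', text, re.S):
        for name, body in INVOKE_RE.findall(block):
            args = {}
            for key, value in PARAM_RE.findall(body):
                try:
                    args[key] = json.loads(value)   # typed value if it is valid JSON
                except json.JSONDecodeError:
                    args[key] = value               # otherwise keep the raw string
            calls.append({"name": name, "arguments": args})
    return calls

sample = ('<think>pick a tool</think>'
          '<minimax:tool_call>'
          '<invoke name="get_weather">'
          '<parameter name="city">Paris</parameter>'
          '<parameter name="days">3</parameter>'
          '</invoke>'
          '</minimax:tool_call>')
print(parse_minimax_m2_tool_calls(sample))
# [{'name': 'get_weather', 'arguments': {'city': 'Paris', 'days': 3}}]
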

+ 1 - 0
common/chat.h

@@ -116,6 +116,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_SEED_OSS,
     COMMON_CHAT_FORMAT_NEMOTRON_V2,
     COMMON_CHAT_FORMAT_APERTUS,
+    COMMON_CHAT_FORMAT_MINIMAX_M2,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

+ 141 - 0
convert_hf_to_gguf.py

@@ -928,6 +928,9 @@ class TextModel(ModelBase):
         if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454":
             # ref: https://huggingface.co/openai-community/gpt2
             res = "gpt-2"
+        if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
+            # ref: MiniMax M2 (GPT2Tokenizer) – recognize as GPT-2 BPE pre-tokenizer
+            res = "gpt-2"
         if chkhsh == "32d85c31273f8019248f2559fed492d929ea28b17e51d81d3bb36fff23ca72b3":
             # ref: https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b
             res = "stablelm2"
@@ -4029,6 +4032,144 @@ class GPT2Model(TextModel):
         return tensors
 
 
+@ModelBase.register("MiniMaxM2ForCausalLM", "MiniMaxM2MiniForCausalLM")
+class MiniMaxM2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MINIMAX_M2
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def set_vocab(self):
+        # Try SentencePiece, then Llama-HF, then GPT2 (merges+vocab)
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            try:
+                self._set_vocab_llama_hf()
+            except FileNotFoundError:
+                self._set_vocab_gpt2()
+
+        tokenizer_config_file = self.dir_model / "tokenizer_config.json"
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+
+        block_count = hparams["num_hidden_layers"]
+        n_embd = hparams["hidden_size"]
+        n_head = hparams["num_attention_heads"]
+        n_head_kv = hparams["num_key_value_heads"]
+        
+        # MiniMax M2 uses partial RoPE: head_dim=128 but only rotary_dim=64 gets RoPE applied
+        rope_dim = hparams.get("rotary_dim", n_embd // n_head)
+
+        # MiniMax M2 expert FFN uses intermediate_size (1536), NOT mlp_intermediate_size (8192)
+        # mlp_intermediate_size in config.json is misleading/unused
+        n_ff = hparams.get("intermediate_size", 8192)
+
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(n_embd)
+        self.gguf_writer.add_feed_forward_length(n_ff)
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000.0))
+        self.gguf_writer.add_file_type(self.ftype)
+
+        if hparams.get("num_local_experts", 0) > 0:
+            self.gguf_writer.add_expert_count(hparams["num_local_experts"])
+            self.gguf_writer.add_expert_used_count(hparams["num_experts_per_tok"])
+            self.gguf_writer.add_expert_feed_forward_length(n_ff)
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        if hparams.get("use_qk_norm", False):
+            self.gguf_writer.add_bool(gguf.Keys.Attention.QK_NORM.format(arch=self.gguf_writer.arch), True)
+        if (eps := hparams.get("attention_qk_norm_eps")) is not None:
+            self.gguf_writer.add_float32(gguf.Keys.Attention.QK_NORM_EPS.format(arch=self.gguf_writer.arch), eps)
+        
+        # Set head dimensions explicitly (critical for GQA models with head_dim != n_embd/n_head)
+        head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])
+        self.gguf_writer.add_uint32(gguf.Keys.Attention.KEY_LENGTH.format(arch=self.gguf_writer.arch), head_dim)
+        self.gguf_writer.add_uint32(gguf.Keys.Attention.VALUE_LENGTH.format(arch=self.gguf_writer.arch), head_dim)
+
+    def prepare_metadata(self, vocab_only: bool):
+        super().prepare_metadata(vocab_only=vocab_only)
+        # Override the size label to the '<total rounded to 10B>x10B' form (e.g. MiniMax M2: ~230B total, 10B active -> '230x10B')
+        total_params = self.gguf_writer.get_total_parameter_count()[0]
+        total_b = int(round(total_params / 1e10) * 10)  # round to nearest 10B
+        size_label = f"{total_b}x10B"
+        self.gguf_writer.add_size_label(size_label)
+
+    # Force GPT-2 style BPE pre-tokenizer for MiniMax M2
+    def get_vocab_base_pre(self, tokenizer) -> str:  # type: ignore[override]
+        return "gpt-2"
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        # no MiniMax-M2 specific forced quantization; defer to the default behaviour
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def _flush_experts(self, bid: int, n_experts: int) -> Iterable[tuple[str, Tensor]]:
+        assert self._experts is not None
+        tensors: list[tuple[str, Tensor]] = []
+        buckets = self._experts[bid]
+
+        def _stack(prefix: str) -> Tensor:
+            parts: list[Tensor] = []
+            for xid in range(n_experts):
+                key = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{prefix}.weight"
+                parts.append(buckets[key])
+                del buckets[key]
+            # torch dims: [n_expert, rows, cols]
+            return torch.stack(parts, dim=0)
+
+        # Provide torch dims so GGUF/ggml (which reverses dims) ends up with:
+        # gate/up -> [n_embd, n_ff, n_expert], down -> [n_ff, n_embd, n_expert]
+        # w1, w3 in HF are typically [n_ff, n_embd]; w2 is [n_embd, n_ff].
+        gate = _stack("w1")            # [n_expert, n_ff, n_embd]
+        up   = _stack("w3")            # [n_expert, n_ff, n_embd]
+        down = _stack("w2")           # [n_expert, n_embd, n_ff]
+
+        tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate))
+        tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up))
+        tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid), down))
+        return tensors
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip FP8 quantization scale tensors - they will be handled separately if needed
+        if "weight_scale_inv" in name:
+            return []
+
+        # MoE experts aggregation
+        if name.find("block_sparse_moe.experts") != -1:
+            assert bid is not None
+            n_experts = self.hparams["num_local_experts"]
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+            self._experts[bid][name] = data_torch
+            if len(self._experts[bid]) >= n_experts * 3:
+                return self._flush_experts(bid, n_experts)
+            return []
+
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            leftovers = [k for d in self._experts for k in d.keys()]
+            if leftovers:
+                raise ValueError(f"Unprocessed experts: {leftovers}")
+
 @ModelBase.register("PhiForCausalLM")
 class Phi2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.PHI2
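
A toy sketch of what MiniMaxM2Model._flush_experts does once all n_experts * 3 expert weights of a block have been buffered (tiny made-up shapes, not the real MiniMax-M2 sizes):

import torch

n_experts, n_embd, n_ff = 4, 8, 6
buckets = {}
for xid in range(n_experts):
    buckets[f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight"] = torch.randn(n_ff, n_embd)
    buckets[f"model.layers.0.block_sparse_moe.experts.{xid}.w3.weight"] = torch.randn(n_ff, n_embd)
    buckets[f"model.layers.0.block_sparse_moe.experts.{xid}.w2.weight"] = torch.randn(n_embd, n_ff)

def stack(prefix: str) -> torch.Tensor:
    # stack per-expert matrices into a single [n_expert, rows, cols] tensor
    return torch.stack(
        [buckets[f"model.layers.0.block_sparse_moe.experts.{xid}.{prefix}.weight"]
         for xid in range(n_experts)], dim=0)

gate, up, down = stack("w1"), stack("w3"), stack("w2")
print(gate.shape, up.shape, down.shape)
# torch.Size([4, 6, 8]) torch.Size([4, 6, 8]) torch.Size([4, 8, 6])
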

+ 22 - 0
gguf-py/gguf/constants.py

@@ -138,6 +138,8 @@ class Keys:
         HEAD_COUNT_KV                = "{arch}.attention.head_count_kv"
         MAX_ALIBI_BIAS               = "{arch}.attention.max_alibi_bias"
         CLAMP_KQV                    = "{arch}.attention.clamp_kqv"
+        QK_NORM                      = "{arch}.attention.qk_norm"
+        QK_NORM_EPS                  = "{arch}.attention.qk_norm_eps"
         KEY_LENGTH                   = "{arch}.attention.key_length"
         VALUE_LENGTH                 = "{arch}.attention.value_length"
         LAYERNORM_EPS                = "{arch}.attention.layer_norm_epsilon"
@@ -420,6 +422,7 @@ class MODEL_ARCH(IntEnum):
     SEED_OSS         = auto()
     GROVEMOE         = auto()
     APERTUS          = auto()
+    MINIMAX_M2       = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -766,6 +769,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.SEED_OSS:         "seed_oss",
     MODEL_ARCH.GROVEMOE:         "grovemoe",
     MODEL_ARCH.APERTUS:          "apertus",
+    MODEL_ARCH.MINIMAX_M2:       "minimax-m2",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -1766,6 +1770,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.LAUREL_R,
         MODEL_TENSOR.LAUREL_POST_NORM,
     ],
+    MODEL_ARCH.MINIMAX_M2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.GEMMA_EMBEDDING: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT,
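
The two new attention keys resolve to plain GGUF key strings once formatted with the architecture name registered above; a small sketch:

QK_NORM     = "{arch}.attention.qk_norm"
QK_NORM_EPS = "{arch}.attention.qk_norm_eps"
print(QK_NORM.format(arch="minimax-m2"))      # minimax-m2.attention.qk_norm
print(QK_NORM_EPS.format(arch="minimax-m2"))  # minimax-m2.attention.qk_norm_eps
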

+ 4 - 3
gguf-py/gguf/tensor_mapping.py

@@ -179,7 +179,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
             "encoder.layers.{bid}.attn.Wqkv",                                      # nomic-bert
             "encoder.layers.{bid}.mixer.Wqkv",                                     # jina
-            "model.layers.{bid}.self_attn.qkv_proj",                               # phi3
+            "model.layers.{bid}.self_attn.qkv_proj",                               # phi3 minimax-m2
             "model.layers.layers.{bid}.mixer.qkv_proj",                            # plamo2
             "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
@@ -377,6 +377,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.moe_statics.e_score_correction",        # ernie4.5-moe
             "model.layers.{bid}.mlp.gate.expert_bias",                      # bailingmoe2
             "model.layers.{bid}.feed_forward.expert_bias",                  # lfm2moe
+            "model.layers.{bid}.block_sparse_moe.e_score_correction.bias",  # minimax-m2
         ),
 
         # Feed-forward up
@@ -553,7 +554,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.query_layernorm",                   # hunyuan
             "model.layers.{bid}.attention.query_layernorm",                   # bailingmoe2
-            "model.layers.{bid}.self_attn.q_norm",                            # cohere olmoe chameleon olmo2
+            "model.layers.{bid}.self_attn.q_norm",                            # cohere olmoe chameleon olmo2 minimax-m2
             "layers.{bid}.self_attn.q_norm",                                  # embeddinggemma
             "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",                # jina-bert-v2
@@ -568,7 +569,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.key_layernorm",                     # hunyuan
             "model.layers.{bid}.attention.key_layernorm",                     # bailingmoe2
-            "model.layers.{bid}.self_attn.k_norm",                            # cohere olmoe chameleon olmo2
+            "model.layers.{bid}.self_attn.k_norm",                            # cohere olmoe chameleon olmo2 minimax-m2
             "layers.{bid}.self_attn.k_norm",                                  # embeddinggemma
             "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",                # jina-bert-v2

+ 22 - 0
src/llama-arch.cpp

@@ -103,6 +103,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SEED_OSS,         "seed_oss"         },
     { LLM_ARCH_GROVEMOE,         "grovemoe"         },
     { LLM_ARCH_APERTUS,          "apertus"          },
+    { LLM_ARCH_MINIMAX_M2,       "minimax-m2"       },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -779,6 +780,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_MINIMAX_M2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {

+ 1 - 0
src/llama-arch.h

@@ -107,6 +107,7 @@ enum llm_arch {
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_GROVEMOE,
     LLM_ARCH_APERTUS,
+    LLM_ARCH_MINIMAX_M2,
     LLM_ARCH_UNKNOWN,
 };
 

+ 229 - 0
src/llama-model.cpp

@@ -125,6 +125,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_355B_A32B:     return "355B.A32B";
         case LLM_TYPE_E2B:           return "E2B";
         case LLM_TYPE_E4B:           return "E4B";
+        case LLM_TYPE_256xA10B:      return "230x10B";
         default:                     return "?B";
     }
 }
@@ -2124,6 +2125,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MINIMAX_M2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func);
+
+                // MiniMax M2 uses GQA with head_dim=128, not n_embd/n_head=64
+                // Override if KEY_LENGTH is not explicitly set in GGUF
+                if (hparams.n_embd_head_k == hparams.n_embd / hparams.n_head()) {
+                    // Model uses GQA: n_head=48, n_head_kv=8, head_dim=128
+                    // Q dim = 48*128=6144, K/V dim = 8*128=1024
+                    hparams.n_embd_head_k = 128;
+                    hparams.n_embd_head_v = 128;
+                }
+
+                switch (hparams.n_layer) {
+                    case 62: type = LLM_TYPE_256xA10B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -2575,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
+                        // Gate/Up in file are ordered [n_embd, n_ff_exp, n_expert]
                         layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
@@ -3263,6 +3285,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         // MoE branch
                         const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
+                        // optional router bias (e_score_correction.bias)
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                        // merged expert tensors
                         layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
@@ -3349,6 +3375,59 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         // MoE branch
                         const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
+                        // optional router bias (e_score_correction.bias)
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                        // merged expert tensors
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_MINIMAX_M2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        // QK norm (per-head: each of n_head Q heads and n_head_kv K heads has separate norm params)
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k * n_head_kv}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0 for MINIMAX_M2");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0 for MINIMAX_M2");
+                        }
+
+                        // MoE branch
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                        // optional router bias (e_score_correction_bias -> exp_probs_b, no suffix)
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, nullptr, i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                        // merged expert tensors
                         layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
@@ -9484,6 +9563,151 @@ struct llm_build_qwen3 : public llm_graph_context {
     }
 };
 
+struct llm_build_minimax_m2 : public llm_graph_context {
+    llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        // MiniMax M2 uses partial RoPE: head_dim=128, rotary_dim=64
+
+        llama_expert_gating_func_type gating_func =
+            static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func);
+        if (gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
+            gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
+        }
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                // MiniMax M2: QK norm is applied to flattened Q/K before reshape
+                // Q: {n_embd_head_k * n_head, n_tokens} -> norm -> reshape to 3D
+                // K: {n_embd_head_k * n_head_kv, n_tokens} -> norm -> reshape to 3D
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Qcur, "Qcur_normed", il);
+                }
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Kcur, "Kcur_normed", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_gate_inp_b,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_up_exps_b,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_gate_exps_b,
+                        model.layers[il].ffn_down_exps,
+                        model.layers[il].ffn_down_exps_b,
+                        model.layers[il].ffn_exp_probs_b,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0f,
+                        gating_func,
+                        il);
+            cb(moe_out, "ffn_moe_out", il);
+            cur = moe_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_qwen3moe : public llm_graph_context {
     llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19888,6 +20112,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_qwen3moe>(*this, params);
             } break;
+        case LLM_ARCH_MINIMAX_M2:
+            {
+                llm = std::make_unique<llm_build_minimax_m2>(*this, params);
+            } break;
         case LLM_ARCH_PHI2:
             {
                 llm = std::make_unique<llm_build_phi2>(*this, params);
@@ -20397,6 +20625,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_SEED_OSS:
         case LLM_ARCH_GROVEMOE:
         case LLM_ARCH_APERTUS:
+        case LLM_ARCH_MINIMAX_M2:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
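
A quick arithmetic check of the attention shapes assumed by the MINIMAX_M2 load_tensors branch and llm_build_minimax_m2 above (the numbers come from the comments in this diff and are illustrative, not read from a config):

n_embd, n_head, n_head_kv = 3072, 48, 8   # n_embd / n_head = 64, but...
head_dim, rotary_dim = 128, 64            # ...head_dim is 128, with partial RoPE over 64 dims

assert head_dim != n_embd // n_head       # why KEY_LENGTH / VALUE_LENGTH must be set explicitly

print("wq:", (n_embd, head_dim * n_head))     # (3072, 6144)
print("wk:", (n_embd, head_dim * n_head_kv))  # (3072, 1024)
print("wv:", (n_embd, head_dim * n_head_kv))  # (3072, 1024)
print("wo:", (head_dim * n_head, n_embd))     # (6144, 3072)
print("RoPE rotates", rotary_dim, "of", head_dim, "dims per head")
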

+ 1 - 0
src/llama-model.h

@@ -119,6 +119,7 @@ enum llm_type {
     LLM_TYPE_355B_A32B, // GLM-4.5
     LLM_TYPE_E2B,
     LLM_TYPE_E4B,
+    LLM_TYPE_256xA10B, // MiniMax M2 - 256 experts, 10B active
 };
 
 std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);

BIN
tools/server/public/index.html.gz


+ 33 - 6
tools/server/webui/src/lib/services/chat.ts

@@ -114,10 +114,23 @@ export class ChatService {
 		const processedMessages = this.injectSystemMessage(normalizedMessages);
 
 		const requestBody: ApiChatCompletionRequest = {
-			messages: processedMessages.map((msg: ApiChatMessageData) => ({
-				role: msg.role,
-				content: msg.content
-			})),
+			messages: processedMessages.map((msg: ApiChatMessageData) => {
+				const apiMsg: {
+					role: ChatRole;
+					content: string | ApiChatMessageContentPart[];
+					reasoning_content?: string;
+				} = {
+					role: msg.role,
+					content: msg.content
+				};
+				
+				// Include reasoning_content if present (for interleaved thinking models like MiniMax-M2)
+				if (msg.reasoning_content) {
+					apiMsg.reasoning_content = msg.reasoning_content;
+				}
+				
+				return apiMsg;
+			}),
 			stream
 		};
 
@@ -449,10 +462,17 @@ export class ChatService {
 		message: DatabaseMessage & { extra?: DatabaseMessageExtra[] }
 	): ApiChatMessageData {
 		if (!message.extra || message.extra.length === 0) {
-			return {
+			const result: ApiChatMessageData = {
 				role: message.role as 'user' | 'assistant' | 'system',
 				content: message.content
 			};
+			
+			// Preserve reasoning content (thinking) for interleaved thinking models
+			if (message.thinking) {
+				result.reasoning_content = message.thinking;
+			}
+			
+			return result;
 		}
 
 		const contentParts: ApiChatMessageContentPart[] = [];
@@ -537,10 +557,17 @@ export class ChatService {
 			}
 		}
 
-		return {
+		const result: ApiChatMessageData = {
 			role: message.role as 'user' | 'assistant' | 'system',
 			content: contentParts
 		};
+		
+		// Preserve reasoning content (thinking) for interleaved thinking models
+		if (message.thinking) {
+			result.reasoning_content = message.thinking;
+		}
+		
+		return result;
 	}
 
 	/**

+ 2 - 0
tools/server/webui/src/lib/types/api.d.ts

@@ -33,6 +33,7 @@ export interface ApiErrorResponse {
 export interface ApiChatMessageData {
 	role: ChatRole;
 	content: string | ApiChatMessageContentPart[];
+	reasoning_content?: string;
 	timestamp?: number;
 }
 
@@ -153,6 +154,7 @@ export interface ApiChatCompletionRequest {
 	messages: Array<{
 		role: ChatRole;
 		content: string | ApiChatMessageContentPart[];
+		reasoning_content?: string;
 	}>;
 	stream?: boolean;
 	model?: string;
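
For illustration, the request body the webui now sends once stored thinking is forwarded as reasoning_content (a hypothetical conversation; field names match the types above):

import json

request_body = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        # assistant history entry: content plus the preserved reasoning_content
        {"role": "assistant", "content": "4", "reasoning_content": "Simple arithmetic."},
        {"role": "user", "content": "And 3 + 3?"},
    ],
    "stream": True,
}
print(json.dumps(request_body, indent=2))
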