
model: support Ministral3 (#17644)

* conversion script

* support ministral 3

* maybe this is better?

* add TODO for rope_yarn_log_mul

* better ppl (tested on 14B-Instruct)

* Add Ministral3 support to Mistral format

* improve arch handling

* add sizes

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* nits

---------

Co-authored-by: Julien Denize <julien.denize@mistral.ai>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Xuan-Son Nguyen 1 month ago
Parent
Commit
cd3c118908

+ 70 - 4
convert_hf_to_gguf.py

@@ -1581,10 +1581,27 @@ class MmprojModel(ModelBase):
 
 
        # load preprocessor config
        self.preprocessor_config = {}
-        if not self.is_mistral_format:
-            with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+
+        # prefer preprocessor_config.json if possible
+        preprocessor_config_path = self.dir_model / "preprocessor_config.json"
+        if preprocessor_config_path.is_file():
+            with open(preprocessor_config_path, "r", encoding="utf-8") as f:
                self.preprocessor_config = json.load(f)
 
 
+        # prefer processor_config.json if possible
+        processor_config_path = self.dir_model / "processor_config.json"
+        if processor_config_path.is_file():
+            with open(processor_config_path, "r", encoding="utf-8") as f:
+                cfg = json.load(f)
+                # move image_processor to root level for compat
+                if "image_processor" in cfg:
+                    cfg = {
+                        **cfg,
+                        **cfg["image_processor"],
+                    }
+                # merge configs
+                self.preprocessor_config = {**self.preprocessor_config, **cfg}
+
    def get_vision_config(self) -> dict[str, Any] | None:
        config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
        return self.global_config.get(config_name)
@@ -2797,7 +2814,32 @@ class Llama4VisionModel(MmprojModel):
 
 
@ModelBase.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_arch = gguf.MODEL_ARCH.MISTRAL3
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # for compatibility, we use LLAMA arch for older models
+        # TODO: remove this once everyone has migrated to newer version of llama.cpp
+        if self.hparams.get("model_type") != "ministral3":
+            self.model_arch = gguf.MODEL_ARCH.LLAMA
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        rope_params = self.hparams.get("rope_parameters")
+        if self.hparams.get("model_type") == "ministral3":
+            assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
+            assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
+            self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
 
 
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        name = name.replace("language_model.", "")
@@ -9809,12 +9851,22 @@ class ApertusModel(LlamaModel):
 
 
 
 
class MistralModel(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_arch = gguf.MODEL_ARCH.MISTRAL3
    model_name = "Mistral"
    hf_arch = ""
    is_mistral_format = True
    undo_permute = False
 
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # for compatibility, we use LLAMA arch for older models
+        # TODO: remove this once everyone migrates to newer version of llama.cpp
+        if "llama_4_scaling" not in self.hparams:
+            self.model_arch = gguf.MODEL_ARCH.LLAMA
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
    @staticmethod
    def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
        assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
@@ -9854,6 +9906,20 @@ class MistralModel(LlamaModel):
 
 
        return template
 
 
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if "yarn" in self.hparams:
+            yarn_params = self.hparams["yarn"]
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
+
+        if "llama_4_scaling" in self.hparams:
+            self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
+
 
 
class PixtralModel(LlavaVisionModel):
    model_name = "Pixtral"

+ 23 - 0
gguf-py/gguf/constants.py

@@ -175,6 +175,7 @@ class Keys:
        VALUE_LENGTH_MLA             = "{arch}.attention.value_length_mla"
        SHARED_KV_LAYERS             = "{arch}.attention.shared_kv_layers"
        SLIDING_WINDOW_PATTERN       = "{arch}.attention.sliding_window_pattern"
+        TEMPERATURE_SCALE            = "{arch}.attention.temperature_scale"
 
 
    class Rope:
        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
@@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum):
    MINIMAXM2        = auto()
    RND1             = auto()
    PANGU_EMBED      = auto()
+    MISTRAL3         = auto()
 
 
 
 
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -817,6 +819,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.COGVLM:           "cogvlm",
    MODEL_ARCH.RND1:             "rnd1",
    MODEL_ARCH.PANGU_EMBED:      "pangu-embedded",
+    MODEL_ARCH.MISTRAL3:         "mistral3",
}
 
 
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -3071,6 +3074,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
+    MODEL_ARCH.MISTRAL3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
    # TODO
}
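
Because MODEL_ARCH.MISTRAL3 now has entries in MODEL_ARCH_NAMES and MODEL_TENSORS, the usual tensor-name mapping (the same helper the conversion script calls in Mistral3Model.__init__ above) covers the new arch; a small sketch with an illustrative block count and tensor name:

    import gguf

    # maps HF-style tensor names onto GGUF names for the new arch
    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MISTRAL3, 40)
    name = tmap.get_name("model.layers.0.self_attn.q_proj.weight", try_suffixes=(".weight", ".bias"))
    print(name)  # expected to resolve to "blk.0.attn_q.weight"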
 
 

+ 3 - 0
gguf-py/gguf/gguf_writer.py

@@ -904,6 +904,9 @@ class GGUFWriter:
    def add_attn_temperature_length(self, value: int) -> None:
        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
 
 
+    def add_attn_temperature_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
+
    def add_pooling_type(self, value: PoolingType) -> None:
        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
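
The new add_attn_temperature_scale writes a single float32 under the per-arch key added to constants.py above. A hedged usage sketch (output path and value are illustrative only):

    import gguf

    writer = gguf.GGUFWriter("model.gguf", arch="mistral3")
    writer.add_attn_temperature_scale(0.5)  # stored as "mistral3.attention.temperature_scale" (float32)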
 
 

+ 1 - 0
src/CMakeLists.txt

@@ -132,6 +132,7 @@ add_library(llama
            models/t5-enc.cpp
            models/wavtokenizer-dec.cpp
            models/xverse.cpp
+            models/mistral3.cpp
            models/graph-context-mamba.cpp
            )
 
 

+ 28 - 0
src/llama-arch.cpp

@@ -111,6 +111,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_COGVLM,           "cogvlm"           },
    { LLM_ARCH_RND1,             "rnd1"             },
    { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
+    { LLM_ARCH_MISTRAL3,         "mistral3"         },
    { LLM_ARCH_UNKNOWN,          "(unknown)"        },
};
 
 
@@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
    { LLM_KV_ATTENTION_OUTPUT_SCALE,                 "%s.attention.output_scale"                 },
    { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,           "%s.attention.temperature_length"           },
+    { LLM_KV_ATTENTION_TEMPERATURE_SCALE,            "%s.attention.temperature_scale"            },
    { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
 
@@ -2512,6 +2514,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
        },
    },
+    {
+        LLM_ARCH_MISTRAL3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
    {
        LLM_ARCH_UNKNOWN,
        {

+ 2 - 0
src/llama-arch.h

@@ -115,6 +115,7 @@ enum llm_arch {
    LLM_ARCH_COGVLM,
    LLM_ARCH_RND1,
    LLM_ARCH_PANGU_EMBED,
+    LLM_ARCH_MISTRAL3,
    LLM_ARCH_UNKNOWN,
};
 
 
@@ -208,6 +209,7 @@ enum llm_kv {
    LLM_KV_ATTENTION_SCALE,
    LLM_KV_ATTENTION_OUTPUT_SCALE,
    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
+    LLM_KV_ATTENTION_TEMPERATURE_SCALE,
    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
 

+ 3 - 0
src/llama-graph.cpp

@@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;
 
 
+        GGML_ASSERT(f_attn_temp_scale != 0.0f);
+        GGML_ASSERT(n_attn_temp_floor_scale != 0);
+
        std::vector<float> attn_scale_data(n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            const float pos = ubatch->pos[i];
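
The hunk above only adds the two sanity asserts; the per-token scale itself is computed just below this context, outside the diff. For intuition, a rough Python rendering of Llama-4-style attention temperature scaling driven by these two hparams (the exact expression in llama-graph.cpp may differ slightly):

    import math

    def attn_temp_scale(pos: int, floor_scale: int, temp_scale: float) -> float:
        # grows logarithmically with position, stepping every floor_scale tokens;
        # zero floor_scale or temp_scale would make this degenerate, hence the asserts
        return math.log(math.floor((pos + 1.0) / floor_scale) + 1.0) * temp_scale + 1.0

    print(attn_temp_scale(0, 8192, 0.1))        # 1.0 at the start of the context
    print(attn_temp_scale(100_000, 8192, 0.1))  # ~1.26 deep into the context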

+ 2 - 2
src/llama-hparams.h

@@ -162,8 +162,8 @@ struct llama_hparams {
    // llama4 smallthinker
    uint32_t n_moe_layer_step        = 0;
    uint32_t n_no_rope_layer_step    = 4;
-    uint32_t n_attn_temp_floor_scale = 8192;
-    float    f_attn_temp_scale       = 0.1;
+    uint32_t n_attn_temp_floor_scale = 0;
+    float    f_attn_temp_scale       = 0.0f;
 
 
    // gemma3n altup
    uint32_t n_altup      = 4; // altup_num_inputs

+ 46 - 4
src/llama-model.cpp

@@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
    switch (arch) {
        case LLM_ARCH_LLAMA:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
                if (hparams.n_expert == 8) {
                    switch (hparams.n_layer) {
                        case 32: type = LLM_TYPE_8x7B; break;
@@ -663,8 +661,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    hparams.swa_type             = LLAMA_SWA_TYPE_NONE;
                    hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
                } else {
-                    hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
-                    hparams.n_swa         = 8192;
+                    hparams.swa_type                = LLAMA_SWA_TYPE_CHUNKED;
+                    hparams.n_swa                   = 8192;
+                    hparams.n_attn_temp_floor_scale = 8192;
+                    hparams.f_attn_temp_scale       = 0.1f;
                    hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
                }
 
 
@@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
+        case LLM_ARCH_MISTRAL3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
+
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   hparams.yarn_beta_fast, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   hparams.yarn_beta_slow, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,     hparams.rope_yarn_log_mul, false);
+
+                // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
+                if (hparams.f_attn_temp_scale != 0.0f) {
+                    hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
+                    if (hparams.n_attn_temp_floor_scale == 0) {
+                        throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
+                    }
+                }
+
+                // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
+                //       but may need further verification with other values
+                if (hparams.rope_yarn_log_mul != 0.0f) {
+                    float factor = 1.0f / hparams.rope_freq_scale_train;
+                    float mscale = 1.0f;
+                    float mscale_all_dims = hparams.rope_yarn_log_mul;
+                    static auto get_mscale = [](float scale, float mscale) {
+                        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+                    };
+                    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+                }
+
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_3B; break;
+                    case 34: type = LLM_TYPE_8B; break;
+                    case 40: type = LLM_TYPE_14B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
        default: throw std::runtime_error("unsupported model architecture");
    }
 
 
@@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            case LLM_ARCH_MINICPM:
            case LLM_ARCH_GRANITE:
            case LLM_ARCH_GRANITE_MOE:
+            case LLM_ARCH_MISTRAL3:
                {
                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
 
@@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_qwen3next>(*this, params);
            } break;
+        case LLM_ARCH_MISTRAL3:
+            {
+                llm = std::make_unique<llm_build_mistral3>(*this, params);
+            } break;
        default:
            GGML_ABORT("fatal error");
    }
@@ -7690,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_ARCEE:
        case LLM_ARCH_ERNIE4_5:
        case LLM_ARCH_ERNIE4_5_MOE:
+        case LLM_ARCH_MISTRAL3:
            return LLAMA_ROPE_TYPE_NORM;
 
 
        // the pairs of head values are offset by n_rot/2
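
The yarn_attn_factor adjustment in the LLM_ARCH_MISTRAL3 case above is a ratio of two mscale terms; a small numeric sketch of the same computation (mscale fixed at 1.0 and mscale_all_dim taken from rope.scaling.yarn_log_mul, exactly as in the hunk):

    import math

    def get_mscale(scale: float, mscale: float) -> float:
        # same form as the lambda in llama-model.cpp above
        return 1.0 if scale <= 1.0 else 0.1 * mscale * math.log(scale) + 1.0

    def yarn_attn_factor(rope_freq_scale_train: float, mscale_all_dim: float) -> float:
        factor = 1.0 / rope_freq_scale_train
        return get_mscale(factor, 1.0) / get_mscale(factor, mscale_all_dim)

    # with mscale == mscale_all_dim == 1.0 the ratio is exactly 1.0, the
    # configuration the TODO in the hunk notes has been verified so far
    print(yarn_attn_factor(0.25, 1.0))  # 1.0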

+ 160 - 0
src/models/mistral3.cpp

@@ -0,0 +1,160 @@
+#include "models.h"
+
+llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // (optional) temperature tuning
+    ggml_tensor * inp_attn_scale = nullptr;
+    if (hparams.f_attn_temp_scale != 0.0f) {
+        inp_attn_scale = build_inp_attn_scale();
+    }
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            // rope freq factors for llama3; may return nullptr for llama2 and other models
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+            if (model.layers[il].bq) {
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+            }
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+            if (model.layers[il].bk) {
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+            }
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+            if (model.layers[il].bv) {
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+            }
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            if (inp_attn_scale) {
+                // apply llama 4 temperature scaling
+                Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
+                cb(Qcur, "Qcur_attn_temp_scaled", il);
+            }
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            cb(cur, "attn_out", il);
+        }
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network (non-MoE)
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cb(cur, "ffn_out", il);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}

+ 4 - 0
src/models/models.h

@@ -322,6 +322,10 @@ struct llm_build_minimax_m2 : public llm_graph_context {
    llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
};
 
 
+struct llm_build_mistral3 : public llm_graph_context {
+    llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_mpt : public llm_graph_context {
    llm_build_mpt(const llama_model & model, const llm_graph_params & params);
};