فهرست منبع

convert: allow using quantized Mistral weight (#17889)

* convert: allow using quantized Mistral weight

* data_torch.ndim

* update dequant fn

Co-authored-by: compilade <compilade@users.noreply.github.com>

---------

Co-authored-by: compilade <compilade@users.noreply.github.com>
Xuan-Son Nguyen 1 ماه پیش
والد
کامیت
9e79b0116e
1فایلهای تغییر یافته به همراه24 افزوده شده و 4 حذف شده
  1. 24 4
      convert_hf_to_gguf.py

+ 24 - 4
convert_hf_to_gguf.py

@@ -383,6 +383,17 @@ class ModelBase:
                         s = self.model_tensors[name]
                         self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                         tensors_to_remove.append(name)
+                    if name.endswith(".activation_scale"):  # unused
+                        tensors_to_remove.append(name)
+                    # mistral format
+                    if name.endswith(".qscale_weight"):
+                        weight_name = name.removesuffix("qscale_weight") + "weight"
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
+                        tensors_to_remove.append(name)
+                    if name.endswith(".qscale_act"):
+                        tensors_to_remove.append(name)
             elif quant_method == "gptq":
                 for name in self.model_tensors.keys():
                     if name.endswith(".qweight"):
@@ -2854,13 +2865,10 @@ class Mistral3Model(LlamaModel):
             self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
-        # TODO: probably not worth supporting quantized weight, as official BF16 is also available
-        if name.endswith("weight_scale_inv"):
-            raise ValueError("This is a quantized weight, please use BF16 weight instead")
-
         name = name.replace("language_model.", "")
         if "multi_modal_projector" in name or "vision_tower" in name:
             return []
+
         return super().modify_tensors(data_torch, name, bid)
 
 
@@ -9898,6 +9906,18 @@ class MistralModel(LlamaModel):
             self.gguf_writer.add_architecture()
             self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
 
+    def dequant_model(self):
+        # transform quantization config into HF format
+        quant_config = self.hparams.get("quantization")
+        if quant_config is not None:
+            assert quant_config["qformat_weight"] == "fp8_e4m3"
+            self.hparams["quantization_config"] = {
+                "activation_scheme": "static",
+                "quant_method": "fp8",
+                "weight_block_size": None,
+            }
+        return super().dequant_model()
+
     @staticmethod
     def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
         assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg