Просмотр исходного кода

convert: fix Mistral3/Gemma3 model hparams init (#12571)

* Fix Mistral3/Gemma3 model hparams init

* set positional args correctly

* use existing hparams if passed
Sigbjørn Skjæret 9 месяцев назад
Родитель
Сommit
53af4dba42
1 измененных файлов с 3 добавлено и 3 удалено
  1. 3 3
      convert_hf_to_gguf.py

+ 3 - 3
convert_hf_to_gguf.py

@@ -1752,7 +1752,7 @@ class Mistral3Model(LlamaModel):
 
 
     # we need to merge the text_config into the root level of hparams
     # we need to merge the text_config into the root level of hparams
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-        hparams = Model.load_hparams(kwargs["dir_model"])
+        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
         if "text_config" in hparams:
         if "text_config" in hparams:
             hparams = {**hparams, **hparams["text_config"]}
             hparams = {**hparams, **hparams["text_config"]}
             kwargs["hparams"] = hparams
             kwargs["hparams"] = hparams
@@ -3385,7 +3385,7 @@ class Gemma3Model(Model):
 
 
     # we need to merge the text_config into the root level of hparams
     # we need to merge the text_config into the root level of hparams
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-        hparams = Model.load_hparams(kwargs["dir_model"])
+        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
         if "text_config" in hparams:
         if "text_config" in hparams:
             hparams = {**hparams, **hparams["text_config"]}
             hparams = {**hparams, **hparams["text_config"]}
             kwargs["hparams"] = hparams
             kwargs["hparams"] = hparams
@@ -5358,7 +5358,7 @@ def main() -> None:
             logger.error(f"Model {model_architecture} is not supported")
             logger.error(f"Model {model_architecture} is not supported")
             sys.exit(1)
             sys.exit(1)
 
 
-        model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
+        model_instance = model_class(dir_model, output_type, fname_out,
                                      is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
                                      is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
                                      eager=args.no_lazy,
                                      eager=args.no_lazy,
                                      metadata_override=args.metadata, model_name=args.model_name,
                                      metadata_override=args.metadata, model_name=args.model_name,