Przeglądaj źródła

convert : fix TypeError when loading base model remotely in convert_lora_to_gguf (#17385)

* fix: TypeError when loading base model remotely in convert_lora_to_gguf

* refactor: simplify base model loading using cache_dir from HuggingFace

* Update convert_lora_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* feat: add remote_hf_model_id to trigger lazy mode in LoRA converter

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
o7si 1 miesiąc temu
rodzic
commit
5088b435d4
1 zmienionych plików z 10 dodań i 4 usunięć
  1. 10 4
      convert_lora_to_gguf.py

+ 10 - 4
convert_lora_to_gguf.py

@@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
+    from huggingface_hub import try_to_load_from_cache
+
     # normally, adapter does not come with base model config, we need to load it from AutoConfig
     config = AutoConfig.from_pretrained(hf_model_id)
-    return config.to_dict()
+    cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
+    cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
+
+    return config.to_dict(), cache_dir
 
 
 if __name__ == '__main__':
@@ -325,13 +330,13 @@ if __name__ == '__main__':
     # load base model
     if base_model_id is not None:
         logger.info(f"Loading base model from Hugging Face: {base_model_id}")
-        hparams = load_hparams_from_hf(base_model_id)
+        hparams, dir_base_model = load_hparams_from_hf(base_model_id)
     elif dir_base_model is None:
         if "base_model_name_or_path" in lparams:
             model_id = lparams["base_model_name_or_path"]
             logger.info(f"Loading base model from Hugging Face: {model_id}")
             try:
-                hparams = load_hparams_from_hf(model_id)
+                hparams, dir_base_model = load_hparams_from_hf(model_id)
             except OSError as e:
                 logger.error(f"Failed to load base model config: {e}")
                 logger.error("Please try downloading the base model and add its path to --base")
@@ -480,6 +485,7 @@ if __name__ == '__main__':
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
             hparams=hparams,
+            remote_hf_model_id=base_model_id,
         )
 
         logger.info("Exporting model...")