|
@@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
|
|
|
return parser.parse_args()
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
-def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
|
|
|
|
|
|
|
+def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
|
|
|
|
|
+ from huggingface_hub import try_to_load_from_cache
|
|
|
|
|
+
|
|
|
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
|
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
|
|
config = AutoConfig.from_pretrained(hf_model_id)
|
|
config = AutoConfig.from_pretrained(hf_model_id)
|
|
|
- return config.to_dict()
|
|
|
|
|
|
|
+ cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
|
|
|
|
|
+ cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
|
|
|
|
|
+
|
|
|
|
|
+ return config.to_dict(), cache_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
@@ -325,13 +330,13 @@ if __name__ == '__main__':
|
|
|
# load base model
|
|
# load base model
|
|
|
if base_model_id is not None:
|
|
if base_model_id is not None:
|
|
|
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
|
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
|
|
- hparams = load_hparams_from_hf(base_model_id)
|
|
|
|
|
|
|
+ hparams, dir_base_model = load_hparams_from_hf(base_model_id)
|
|
|
elif dir_base_model is None:
|
|
elif dir_base_model is None:
|
|
|
if "base_model_name_or_path" in lparams:
|
|
if "base_model_name_or_path" in lparams:
|
|
|
model_id = lparams["base_model_name_or_path"]
|
|
model_id = lparams["base_model_name_or_path"]
|
|
|
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
|
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
|
|
try:
|
|
try:
|
|
|
- hparams = load_hparams_from_hf(model_id)
|
|
|
|
|
|
|
+ hparams, dir_base_model = load_hparams_from_hf(model_id)
|
|
|
except OSError as e:
|
|
except OSError as e:
|
|
|
logger.error(f"Failed to load base model config: {e}")
|
|
logger.error(f"Failed to load base model config: {e}")
|
|
|
logger.error("Please try downloading the base model and add its path to --base")
|
|
logger.error("Please try downloading the base model and add its path to --base")
|
|
@@ -480,6 +485,7 @@ if __name__ == '__main__':
|
|
|
dir_lora_model=dir_lora,
|
|
dir_lora_model=dir_lora,
|
|
|
lora_alpha=alpha,
|
|
lora_alpha=alpha,
|
|
|
hparams=hparams,
|
|
hparams=hparams,
|
|
|
|
|
+ remote_hf_model_id=base_model_id,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
logger.info("Exporting model...")
|
|
logger.info("Exporting model...")
|