Просмотр исходного кода

model-conversion : pass config to from_pretrained (#16963)

This commit modifies the script `run-org-model.py` to ensure that the
model configuration is explicitly passed to the `from_pretrained` method
when loading the model. It also removes a duplicate configuration
loading which was a mistake.

The motivation for this change is that enables the config object to be
modified and then passed to the model loading function, which can be
useful when testing new models.
Daniel Bevenius 2 месяцев назад
Родитель
Сommit
ed8aa63320
1 измененных файлов с 4 добавлено и 5 удалено
  1. 4 5
      examples/model-conversion/scripts/causal/run-org-model.py

+ 4 - 5
examples/model-conversion/scripts/causal/run-org-model.py

@@ -138,6 +138,9 @@ if model_path is None:
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
+
+print("Loading model and tokenizer using AutoTokenizer:", model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 print("Model type:       ", config.model_type)
@@ -147,10 +150,6 @@ print("Number of layers: ", config.num_hidden_layers)
 print("BOS token id:     ", config.bos_token_id)
 print("EOS token id:     ", config.eos_token_id)
 
-print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = (
@@ -171,7 +170,7 @@ if unreleased_model_name:
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
+        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
     )
 
 for name, module in model.named_modules():