Bläddra i källkod

model-conversion : pass config to from_pretrained (#16963)

This commit modifies the script `run-org-model.py` to ensure that the
model configuration is explicitly passed to the `from_pretrained` method
when loading the model. It also removes a duplicate configuration
loading which was a mistake.

The motivation for this change is that enables the config object to be
modified and then passed to the model loading function, which can be
useful when testing new models.
Daniel Bevenius 2 månader sedan
förälder
incheckning
ed8aa63320
1 ändrade filer med 4 tillägg och 5 borttagningar
  1. 4 5
      examples/model-conversion/scripts/causal/run-org-model.py

+ 4 - 5
examples/model-conversion/scripts/causal/run-org-model.py

@@ -138,6 +138,9 @@ if model_path is None:
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
+
+print("Loading model and tokenizer using AutoTokenizer:", model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 print("Model type:       ", config.model_type)
@@ -147,10 +150,6 @@ print("Number of layers: ", config.num_hidden_layers)
 print("BOS token id:     ", config.bos_token_id)
 print("EOS token id:     ", config.eos_token_id)
 
-print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = (
@@ -171,7 +170,7 @@ if unreleased_model_name:
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
+        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
     )
 
 for name, module in model.named_modules():