
Extend run-org-model.py, add (a) batching (b) loading prompt from file (c) multimodal capacity (#18034)

Piotr Wilkin (ilintar) 1 month ago
commit 8faa87db02
1 file changed, 32 additions and 9 deletions
  1. examples/model-conversion/scripts/causal/run-org-model.py (+32, −9)

+32 −9
examples/model-conversion/scripts/causal/run-org-model.py

@@ -5,7 +5,7 @@ import os
 import importlib
 from pathlib import Path
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
 
@@ -116,11 +116,11 @@ def debug_hook(name):
     def fn(_m, input, output):
         if isinstance(input, torch.Tensor):
             summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
             summarize(input[0], name + "_in")
         if isinstance(output, torch.Tensor):
             summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
             summarize(output[0], name + "_out")
 
     return fn
@@ -130,6 +130,7 @@ unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
 args = parser.parse_args()
 
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -142,8 +143,13 @@ if model_path is None:
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+multimodal = False
+full_config = config
 
 print("Model type:       ", config.model_type)
+if "vocab_size" not in config and "text_config" in config:
+    config = config.text_config
+    multimodal = True
 print("Vocab size:       ", config.vocab_size)
 print("Hidden size:      ", config.hidden_size)
 print("Number of layers: ", config.num_hidden_layers)
@@ -169,9 +175,14 @@ if unreleased_model_name:
         print(f"Failed to import or load model: {e}")
         exit(1)
 else:
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
-    )
+    if multimodal:
+        model = AutoModelForImageTextToText.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
+        )
 
 for name, module in model.named_modules():
     if len(list(module.children())) == 0:  # only leaf modules
@@ -185,7 +196,10 @@ model_name = os.path.basename(model_path)
 print(f"Model class: {model.__class__.__name__}")
 
 device = next(model.parameters()).device
-if os.getenv("MODEL_TESTING_PROMPT"):
+if args.prompt_file:
+    with open(args.prompt_file, encoding='utf-8') as f:
+        prompt = f.read()
+elif os.getenv("MODEL_TESTING_PROMPT"):
     prompt = os.getenv("MODEL_TESTING_PROMPT")
 else:
     prompt = "Hello, my name is"
@@ -195,9 +209,18 @@ print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
 
+batch_size = 512
+
 with torch.no_grad():
-    outputs = model(input_ids.to(model.device))
-    logits = outputs.logits
+    past = None
+    outputs = None
+    for i in range(0, input_ids.size(1), batch_size):
+        print(f"Processing chunk with tokens {i} to {i + batch_size}")
+        chunk = input_ids[:, i:i + batch_size]
+        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
+        past = outputs.past_key_values
+
+    logits = outputs.logits # type: ignore
 
     # Extract logits for the last token (next token prediction)
     last_logits = logits[0, -1, :].float().cpu().numpy()
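
For reference, below is a minimal standalone sketch (not part of the commit) of the chunked prefill pattern the diff introduces: the prompt is fed to the model in fixed-size chunks while the KV cache (past_key_values) is carried forward, so the logits of the final chunk match those of a single full-prompt forward pass. The model path and chunk size here are placeholder assumptions, not values from the commit.

# Sketch of the chunked prefill pattern used in the diff above.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "path/to/some-causal-lm"   # placeholder, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
chunk_size = 4  # the script uses 512; small here only for illustration

with torch.no_grad():
    past = None
    outputs = None
    for i in range(0, input_ids.size(1), chunk_size):
        chunk = input_ids[:, i:i + chunk_size]
        # use_cache=True returns past_key_values to feed into the next chunk
        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
        past = outputs.past_key_values

    # Next-token prediction from the last position of the final chunk
    next_id = int(outputs.logits[0, -1, :].argmax())
    print(tokenizer.decode([next_id]))

With the new flag, the updated script itself can be invoked along the lines of: python run-org-model.py -m <model-path> -f prompt.txt (flag names per the argparse changes above).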