
model-conversion : add verbose flag in run-org-model.py (#18194)

This commit adds a --verbose flag to the run-org-model.py script to
enable detailed debug output, such as the input and output tensors for
each layer; without the flag this output is now suppressed. The debug
utilities (summarize, debug_hook, setup_rope_debug) have been moved to
utils/common.py.

The motivation for this is that the detailed debug output can be useful
for diagnosing issues with model conversion or execution, but it can
also produce a large amount of output that may not always be needed.

The script will also be cleaned up and refactored further in follow-up commits.
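
For reference, a typical invocation with the new flag might look like the
following (the model path is just an example):

    python examples/model-conversion/scripts/causal/run-org-model.py \
        --model-path /path/to/model --verbose

Omitting --verbose keeps the per-layer tensor dumps disabled.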
Daniel Bevenius, 4 weeks ago
parent
commit
0a271d82b4

+ 17 - 122
examples/model-conversion/scripts/causal/run-org-model.py

@@ -2,135 +2,22 @@
 
 import argparse
 import os
+import sys
 import importlib
 from pathlib import Path
 
+# Add parent directory to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
-
-### If you want to dump RoPE activations, apply this monkey patch to the model
-### class from Transformers that you are running (replace apertus.modeling_apertus
-### with the proper package and class for your model
-### === START ROPE DEBUG ===
-# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
-
-# orig_rope = apply_rotary_pos_emb
-# torch.set_printoptions(threshold=float('inf'))
-# torch.set_printoptions(precision=6, sci_mode=False)
-
-# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-#     # log inputs
-#     summarize(q, "RoPE.q_in")
-#     summarize(k, "RoPE.k_in")
-
-#     # call original
-#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
-
-#     # log outputs
-#     summarize(q_out, "RoPE.q_out")
-#     summarize(k_out, "RoPE.k_out")
-
-#     return q_out, k_out
-
-# # Patch it
-# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
-# apertus_mod.apply_rotary_pos_emb = debug_rope
-### == END ROPE DEBUG ===
-
-
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
-    """
-    Print a tensor in llama.cpp debug style.
-
-    Supports:
-    - 2D tensors (seq, hidden)
-    - 3D tensors (batch, seq, hidden)
-    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
-
-    Shows first and last max_vals of each vector per sequence position.
-    """
-    t = tensor.detach().to(torch.float32).cpu()
-
-    # Determine dimensions
-    if t.ndim == 3:
-        _, s, _ = t.shape
-    elif t.ndim == 2:
-        _, s = 1, t.shape[0]
-        t = t.unsqueeze(0)
-    elif t.ndim == 4:
-        _, s, _, _ = t.shape
-    else:
-        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
-        return
-
-    ten_shape = t.shape
-
-    print(f"ggml_debug: {name} = (f32)  ... = {{{ten_shape}}}")
-    print("                                     [")
-    print("                                      [")
-
-    # Determine indices for first and last sequences
-    first_indices = list(range(min(s, max_seq)))
-    last_indices = list(range(max(0, s - max_seq), s))
-
-    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
-    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
-    # Combine indices
-    if has_overlap:
-        # If there's overlap, just use the combined unique indices
-        indices = sorted(list(set(first_indices + last_indices)))
-        separator_index = None
-    else:
-        # If no overlap, we'll add a separator between first and last sequences
-        indices = first_indices + last_indices
-        separator_index = len(first_indices)
-
-    for i, si in enumerate(indices):
-        # Add separator if needed
-        if separator_index is not None and i == separator_index:
-            print("                                       ...")
-
-        # Extract appropriate slice
-        vec = t[0, si]
-        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
-            flat = vec.flatten().tolist()
-        else:  # 2D or 3D case
-            flat = vec.tolist()
-
-        # First and last slices
-        first = flat[:max_vals]
-        last = flat[-max_vals:] if len(flat) >= max_vals else flat
-        first_str = ", ".join(f"{v:12.4f}" for v in first)
-        last_str = ", ".join(f"{v:12.4f}" for v in last)
-
-        print(f"                                       [{first_str}, ..., {last_str}]")
-
-    print("                                      ],")
-    print("                                     ]")
-    print(f"                                     sum = {t.sum().item():.6f}\n")
-
-
-def debug_hook(name):
-    def fn(_m, input, output):
-        if isinstance(input, torch.Tensor):
-            summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
-            summarize(input[0], name + "_in")
-        if isinstance(output, torch.Tensor):
-            summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
-            summarize(output[0], name + "_out")
-
-    return fn
-
-
-unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+from utils.common import debug_hook
 
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
 parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
+parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
 args = parser.parse_args()
 
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -139,6 +26,12 @@ if model_path is None:
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
+### If you want to dump RoPE activations, uncomment the following lines:
+### === START ROPE DEBUG ===
+# from utils.common import setup_rope_debug
+# setup_rope_debug("transformers.models.apertus.modeling_apertus")
+### == END ROPE DEBUG ===
+
 
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -156,6 +49,7 @@ print("Number of layers: ", config.num_hidden_layers)
 print("BOS token id:     ", config.bos_token_id)
 print("EOS token id:     ", config.eos_token_id)
 
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = (
@@ -184,9 +78,10 @@ else:
             model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
         )
 
-for name, module in model.named_modules():
-    if len(list(module.children())) == 0:  # only leaf modules
-        module.register_forward_hook(debug_hook(name))
+if args.verbose:
+    for name, module in model.named_modules():
+        if len(list(module.children())) == 0:  # only leaf modules
+            module.register_forward_hook(debug_hook(name))
 
 model_name = os.path.basename(model_path)
 # Printing the Model class to allow for easier debugging. This can be useful

+ 130 - 0
examples/model-conversion/scripts/utils/common.py

@@ -2,6 +2,8 @@
 
 import os
 import sys
+import torch
+
 
 def get_model_name_from_env_path(env_path_name):
     model_path = os.getenv(env_path_name)
@@ -18,3 +20,131 @@ def get_model_name_from_env_path(env_path_name):
         name = name[:-5]
 
     return name
+
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+    """
+    Print a tensor in llama.cpp debug style.
+
+    Supports:
+    - 2D tensors (seq, hidden)
+    - 3D tensors (batch, seq, hidden)
+    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+    Shows first and last max_vals of each vector per sequence position.
+    """
+    t = tensor.detach().to(torch.float32).cpu()
+
+    # Determine dimensions
+    if t.ndim == 3:
+        _, s, _ = t.shape
+    elif t.ndim == 2:
+        _, s = 1, t.shape[0]
+        t = t.unsqueeze(0)
+    elif t.ndim == 4:
+        _, s, _, _ = t.shape
+    else:
+        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+        return
+
+    ten_shape = t.shape
+
+    print(f"ggml_debug: {name} = (f32)  ... = {{{ten_shape}}}")
+    print("                                     [")
+    print("                                      [")
+
+    # Determine indices for first and last sequences
+    first_indices = list(range(min(s, max_seq)))
+    last_indices = list(range(max(0, s - max_seq), s))
+
+    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+    # Combine indices
+    if has_overlap:
+        # If there's overlap, just use the combined unique indices
+        indices = sorted(list(set(first_indices + last_indices)))
+        separator_index = None
+    else:
+        # If no overlap, we'll add a separator between first and last sequences
+        indices = first_indices + last_indices
+        separator_index = len(first_indices)
+
+    for i, si in enumerate(indices):
+        # Add separator if needed
+        if separator_index is not None and i == separator_index:
+            print("                                       ...")
+
+        # Extract appropriate slice
+        vec = t[0, si]
+        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
+            flat = vec.flatten().tolist()
+        else:  # 2D or 3D case
+            flat = vec.tolist()
+
+        # First and last slices
+        first = flat[:max_vals]
+        last = flat[-max_vals:] if len(flat) >= max_vals else flat
+        first_str = ", ".join(f"{v:12.4f}" for v in first)
+        last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+        print(f"                                       [{first_str}, ..., {last_str}]")
+
+    print("                                      ],")
+    print("                                     ]")
+    print(f"                                     sum = {t.sum().item():.6f}\n")
+
+
+def debug_hook(name):
+    def fn(_m, input, output):
+        if isinstance(input, torch.Tensor):
+            summarize(input, name + "_in")
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
+            summarize(input[0], name + "_in")
+        if isinstance(output, torch.Tensor):
+            summarize(output, name + "_out")
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+            summarize(output[0], name + "_out")
+
+    return fn
+
+
+def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
+    """
+    Apply monkey patch to dump RoPE activations for debugging.
+
+    Args:
+        model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
+        function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")
+
+    Example:
+        from utils.common import setup_rope_debug
+        setup_rope_debug("transformers.models.apertus.modeling_apertus")
+    """
+    import importlib
+
+    # Import the module and get the original function
+    module = importlib.import_module(model_module_path)
+    orig_rope = getattr(module, function_name)
+
+    # Set torch print options for better debugging
+    torch.set_printoptions(threshold=float('inf'))
+    torch.set_printoptions(precision=6, sci_mode=False)
+
+    def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        # log inputs
+        summarize(q, "RoPE.q_in")
+        summarize(k, "RoPE.k_in")
+
+        # call original
+        q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+        # log outputs
+        summarize(q_out, "RoPE.q_out")
+        summarize(k_out, "RoPE.k_out")
+
+        return q_out, k_out
+
+    # Patch it
+    setattr(module, function_name, debug_rope)
+    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")