run-org-model.py

#!/usr/bin/env python3
import argparse
import os
import importlib
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
import torch
import numpy as np
### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===
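# Note: the patched debug_rope above calls summarize(), which is defined further
# down in this file. Python resolves the name at call time, so the forward
# reference is fine as long as the patch only fires while the model is running.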
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
      - 2D tensors (seq, hidden)
      - 3D tensors (batch, seq, hidden)
      - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows the first and last max_vals of each vector per sequence position.
    """
    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape

    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we're at the edge case of s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(set(first_indices + last_indices))
        separator_index = None
    else:
        # If no overlap, we'll add a separator between first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= max_vals else flat

        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        print(f" [{first_str}, ..., {last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")
def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")

        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")

    return fn
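# PyTorch forward hooks are invoked as hook(module, input, output), where `input`
# is the tuple of positional arguments the module received; that is why only the
# first tensor of the tuple is summarized above.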
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )
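# Example invocations (paths are placeholders):
#   python run-org-model.py --model-path /path/to/hf-model
#   MODEL_PATH=/path/to/hf-model python run-org-model.py -f prompt.txt
# Note that the MODEL_PATH environment variable takes precedence over --model-path.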
  110. print("Loading model and tokenizer using AutoTokenizer:", model_path)
  111. tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  112. config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
  113. multimodal = False
  114. full_config = config
  115. print("Model type: ", config.model_type)
  116. if "vocab_size" not in config and "text_config" in config:
  117. config = config.text_config
  118. multimodal = True
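# For multimodal checkpoints the text-specific fields (vocab_size, hidden_size,
# etc.) typically live under config.text_config, so the nested config is used for
# the prints below while full_config is kept for loading the model itself.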
  119. print("Vocab size: ", config.vocab_size)
  120. print("Hidden size: ", config.hidden_size)
  121. print("Number of layers: ", config.num_hidden_layers)
  122. print("BOS token id: ", config.bos_token_id)
  123. print("EOS token id: ", config.eos_token_id)
if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")

    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    if multimodal:
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
        )
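# Register a debug hook on every leaf module so each submodule's input and output
# activations are dumped; container modules are skipped to avoid printing the
# same tensors twice.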
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # only leaf modules
        module.register_forward_hook(debug_hook(name))
model_name = os.path.basename(model_path)

# Print the model class to allow for easier debugging. This can be useful when
# working with models that have not been publicly released yet, since that might
# require importing and using the concrete class directly instead of
# AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
if args.prompt_file:
    with open(args.prompt_file, encoding='utf-8') as f:
        prompt = f.read()
elif os.getenv("MODEL_TESTING_PROMPT"):
    prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
    prompt = "Hello, my name is"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

batch_size = 512
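# Long prompts are prefilled in chunks of batch_size tokens. Feeding the returned
# past_key_values back in as the KV cache makes this equivalent (up to numerical
# differences) to a single forward pass over the whole prompt, with lower peak memory.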
with torch.no_grad():
    past = None
    outputs = None
    for i in range(0, input_ids.size(1), batch_size):
        end = min(i + batch_size, input_ids.size(1))
        print(f"Processing chunk with tokens {i} to {end}")
        chunk = input_ids[:, i:end]
        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
        past = outputs.past_key_values

logits = outputs.logits  # type: ignore

# Extract logits for the last token (next-token prediction)
last_logits = logits[0, -1, :].float().cpu().numpy()
print(f"Logits shape: {logits.shape}")
print(f"Last token logits shape: {last_logits.shape}")
print(f"Vocab size: {len(last_logits)}")
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}.bin"
txt_filename = data_dir / f"pytorch-{model_name}.txt"

# Save to file for comparison
last_logits.astype(np.float32).tofile(bin_filename)

# Also save as text file for easy inspection
with open(txt_filename, "w") as f:
    for i, logit in enumerate(last_logits):
        f.write(f"{i}: {logit:.6f}\n")
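# The binary file can be read back for comparison with, for example:
#   np.fromfile("data/pytorch-<model_name>.bin", dtype=np.float32)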
# Print some sample logits for quick verification
print(f"First 10 logits: {last_logits[:10]}")
print(f"Last 10 logits: {last_logits[-10:]}")

# Show the top 5 predicted tokens
top_indices = np.argsort(last_logits)[-5:][::-1]
print("Top 5 predictions:")
for idx in top_indices:
    token = tokenizer.decode([idx])
    print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")