run-org-model.py

#!/usr/bin/env python3
import argparse
import os
import importlib
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np
### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
#
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
      - 2D tensors (seq, hidden)
      - 3D tensors (batch, seq, hidden)
      - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows the first and last max_vals of each vector per sequence position.
    """
    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape
    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we're at
    # the edge case of s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(list(set(first_indices + last_indices)))
        separator_index = None
    else:
        # If no overlap, add a separator between the first and last positions
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= max_vals else flat

        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        print(f" [{first_str}, ..., {last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")


def debug_hook(name):
    def fn(_m, input, output):
        # Forward hooks receive positional inputs as a tuple; guard against empty tuples
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")

    return fn


unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )
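
# Example invocations (paths and model name are placeholders):
#   python run-org-model.py --model-path /path/to/hf-model
#   MODEL_PATH=/path/to/hf-model UNRELEASED_MODEL_NAME=SomeNewModel python run-org-model.py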

config = AutoConfig.from_pretrained(model_path)
print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
  116. print("Loading model and tokenizer using AutoTokenizer:", model_path)
  117. tokenizer = AutoTokenizer.from_pretrained(model_path)
  118. config = AutoConfig.from_pretrained(model_path)
  119. if unreleased_model_name:
  120. model_name_lower = unreleased_model_name.lower()
  121. unreleased_module_path = (
  122. f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
  123. )
  124. class_name = f"{unreleased_model_name}ForCausalLM"
  125. print(f"Importing unreleased model module: {unreleased_module_path}")
  126. try:
  127. model_class = getattr(
  128. importlib.import_module(unreleased_module_path), class_name
  129. )
  130. model = model_class.from_pretrained(
  131. model_path
  132. ) # Note: from_pretrained, not fromPretrained
  133. except (ImportError, AttributeError) as e:
  134. print(f"Failed to import or load model: {e}")
  135. exit(1)
  136. else:
  137. model = AutoModelForCausalLM.from_pretrained(
  138. model_path, device_map="auto", offload_folder="offload"
  139. )
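
# For example (hypothetical), UNRELEASED_MODEL_NAME="Apertus" would resolve to
#   transformers.models.apertus.modular_apertus.ApertusForCausalLM
# which must exist in the installed transformers checkout for the import above to succeed.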

for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # only leaf modules
        module.register_forward_hook(debug_hook(name))

model_name = os.path.basename(model_path)

# Print the model class to allow for easier debugging. This can be useful when
# working with models that have not been publicly released yet, which might
# require that the concrete class is imported and used directly instead of
# AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

with torch.no_grad():
    outputs = model(input_ids.to(model.device))
    logits = outputs.logits

# Extract logits for the last token (next-token prediction)
last_logits = logits[0, -1, :].cpu().numpy()
print(f"Logits shape: {logits.shape}")
print(f"Last token logits shape: {last_logits.shape}")
print(f"Vocab size: {len(last_logits)}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}.bin"
txt_filename = data_dir / f"pytorch-{model_name}.txt"

# Save to a binary file for comparison
last_logits.astype(np.float32).tofile(bin_filename)

# Also save as a text file for easy inspection
with open(txt_filename, "w") as f:
    for i, logit in enumerate(last_logits):
        f.write(f"{i}: {logit:.6f}\n")

# Print some sample logits for quick verification
print(f"First 10 logits: {last_logits[:10]}")
print(f"Last 10 logits: {last_logits[-10:]}")

# Show the top 5 predicted tokens
top_indices = np.argsort(last_logits)[-5:][::-1]
print("Top 5 predictions:")
for idx in top_indices:
    token = tokenizer.decode([idx])
    print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")