common.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. #!/usr/bin/env python3
  2. import os
  3. import sys
  4. import torch
  5. import numpy as np
  6. from pathlib import Path
  7. def get_model_name_from_env_path(env_path_name):
  8. model_path = os.getenv(env_path_name)
  9. if not model_path:
  10. print(f"Error: {env_path_name} environment variable not set")
  11. sys.exit(1)
  12. if not os.path.exists(model_path):
  13. print(f"Error: Model file not found: {model_path}")
  14. sys.exit(1)
  15. name = os.path.basename(os.path.normpath(model_path))
  16. if name.endswith(".gguf"):
  17. name = name[:-5]
  18. return name
  19. def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
  20. """
  21. Print a tensor in llama.cpp debug style.
  22. Supports:
  23. - 2D tensors (seq, hidden)
  24. - 3D tensors (batch, seq, hidden)
  25. - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
  26. Shows first and last max_vals of each vector per sequence position.
  27. """
  28. t = tensor.detach().to(torch.float32).cpu()
  29. # Determine dimensions
  30. if t.ndim == 3:
  31. _, s, _ = t.shape
  32. elif t.ndim == 2:
  33. _, s = 1, t.shape[0]
  34. t = t.unsqueeze(0)
  35. elif t.ndim == 4:
  36. _, s, _, _ = t.shape
  37. else:
  38. print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
  39. return
  40. ten_shape = t.shape
  41. print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
  42. print(" [")
  43. print(" [")
  44. # Determine indices for first and last sequences
  45. first_indices = list(range(min(s, max_seq)))
  46. last_indices = list(range(max(0, s - max_seq), s))
  47. # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
  48. has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
  49. # Combine indices
  50. if has_overlap:
  51. # If there's overlap, just use the combined unique indices
  52. indices = sorted(list(set(first_indices + last_indices)))
  53. separator_index = None
  54. else:
  55. # If no overlap, we'll add a separator between first and last sequences
  56. indices = first_indices + last_indices
  57. separator_index = len(first_indices)
  58. for i, si in enumerate(indices):
  59. # Add separator if needed
  60. if separator_index is not None and i == separator_index:
  61. print(" ...")
  62. # Extract appropriate slice
  63. vec = t[0, si]
  64. if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
  65. flat = vec.flatten().tolist()
  66. else: # 2D or 3D case
  67. flat = vec.tolist()
  68. # First and last slices
  69. first = flat[:max_vals]
  70. last = flat[-max_vals:] if len(flat) >= max_vals else flat
  71. first_str = ", ".join(f"{v:12.4f}" for v in first)
  72. last_str = ", ".join(f"{v:12.4f}" for v in last)
  73. print(f" [{first_str}, ..., {last_str}]")
  74. print(" ],")
  75. print(" ]")
  76. print(f" sum = {t.sum().item():.6f}\n")
  77. def debug_hook(name):
  78. def fn(_m, input, output):
  79. if isinstance(input, torch.Tensor):
  80. summarize(input, name + "_in")
  81. elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
  82. summarize(input[0], name + "_in")
  83. if isinstance(output, torch.Tensor):
  84. summarize(output, name + "_out")
  85. elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
  86. summarize(output[0], name + "_out")
  87. return fn
  88. def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
  89. """
  90. Apply monkey patch to dump RoPE activations for debugging.
  91. Args:
  92. model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
  93. function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")
  94. Example:
  95. from utils.common import setup_rope_debug
  96. setup_rope_debug("transformers.models.apertus.modeling_apertus")
  97. """
  98. import importlib
  99. # Import the module and get the original function
  100. module = importlib.import_module(model_module_path)
  101. orig_rope = getattr(module, function_name)
  102. # Set torch print options for better debugging
  103. torch.set_printoptions(threshold=float('inf'))
  104. torch.set_printoptions(precision=6, sci_mode=False)
  105. def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  106. # log inputs
  107. summarize(q, "RoPE.q_in")
  108. summarize(k, "RoPE.k_in")
  109. # call original
  110. q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
  111. # log outputs
  112. summarize(q_out, "RoPE.q_out")
  113. summarize(k_out, "RoPE.k_out")
  114. return q_out, k_out
  115. # Patch it
  116. setattr(module, function_name, debug_rope)
  117. print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
  118. def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"):
  119. """
  120. Save output data (logits/embeddings), tokens, and prompt to files.
  121. Args:
  122. data: numpy array of floats (logits or embeddings)
  123. tokens: list or array of token IDs
  124. prompt: string containing the input prompt
  125. model_name: name of the model
  126. type_suffix: optional suffix like "-embeddings" (default: "")
  127. output_dir: directory to save files (default: "data")
  128. Creates the following files in output_dir:
  129. - pytorch-{model_name}{type_suffix}.bin
  130. - pytorch-{model_name}{type_suffix}.txt
  131. - pytorch-{model_name}{type_suffix}-prompt.txt
  132. - pytorch-{model_name}{type_suffix}-tokens.bin
  133. """
  134. data_dir = Path(output_dir)
  135. data_dir.mkdir(exist_ok=True)
  136. base_path = data_dir / f"pytorch-{model_name}{type_suffix}"
  137. # Convert and flatten logits/embeddings
  138. data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
  139. data = data.flatten() if data.ndim > 1 else data
  140. # Save logits/embedding files
  141. data.astype(np.float32).tofile(f"{base_path}.bin")
  142. print(f"Data saved to {base_path}.bin")
  143. with open(f"{base_path}.txt", "w") as f:
  144. f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data))
  145. print(f"Data saved to {base_path}.txt")
  146. # Convert and flatten tokens
  147. tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens)
  148. tokens = tokens.flatten() if tokens.ndim > 1 else tokens
  149. # Save token binary file
  150. tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin")
  151. print(f"Tokens saved to {base_path}-tokens.bin")
  152. # Save prompt file
  153. with open(f"{base_path}-prompt.txt", "w") as f:
  154. f.write(f"prompt: {prompt}\n")
  155. f.write(f"n_tokens: {len(tokens)}\n")
  156. f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n")
  157. print(f"Prompt saved to {base_path}-prompt.txt")
  158. def compare_tokens(original, converted, type_suffix="", output_dir="data"):
  159. data_dir = Path(output_dir)
  160. # Read tokens from both models
  161. tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin"
  162. tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin"
  163. if not tokens1_file.exists():
  164. print(f"Error: Token file not found: {tokens1_file}")
  165. return False
  166. if not tokens2_file.exists():
  167. print(f"Error: Token file not found: {tokens2_file}")
  168. return False
  169. tokens1 = np.fromfile(tokens1_file, dtype=np.int32)
  170. tokens2 = np.fromfile(tokens2_file, dtype=np.int32)
  171. print(f"\nComparing tokens between:")
  172. print(f" Original : {original} ({len(tokens1)} tokens)")
  173. print(f" Converted: {converted} ({len(tokens2)} tokens)")
  174. if len(tokens1) != len(tokens2):
  175. print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}")
  176. return False
  177. if np.array_equal(tokens1, tokens2):
  178. print(f"\n✅ All {len(tokens1)} tokens match!")
  179. return True
  180. mismatches = np.where(tokens1 != tokens2)[0]
  181. print(f"\n❌ Found {len(mismatches)} mismatched tokens:")
  182. num_to_show = min(len(mismatches), 10)
  183. for idx in mismatches[:num_to_show]:
  184. print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}")
  185. if len(mismatches) > num_to_show:
  186. print(f" ... and {len(mismatches) - num_to_show} more mismatches")
  187. return False