#!/usr/bin/env python3
import os
import sys
import torch
import transformers
import json
import textwrap
import numpy as np
from pathlib import Path

def get_model_name_from_env_path(env_path_name):
    model_path = os.getenv(env_path_name)
    if not model_path:
        print(f"Error: {env_path_name} environment variable not set")
        sys.exit(1)
    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)
    name = os.path.basename(os.path.normpath(model_path))
    if name.endswith(".gguf"):
        name = name[:-5]
    return name
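
# Hypothetical usage sketch (comment only, not executed): resolve a model name
# from an environment variable. The variable name MODEL_PATH and the path are
# illustrative, not part of this module:
#
#     export MODEL_PATH=/models/llama-3.2-1b.gguf            # shell
#     name = get_model_name_from_env_path("MODEL_PATH")      # -> "llama-3.2-1b"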

def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
        - 2D tensors (seq, hidden)
        - 3D tensors (batch, seq, hidden)
        - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows the first and last max_vals of each vector per sequence position.
    """
    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        s = t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape
    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Treat s == 2 * max_seq like an overlap: the two halves are contiguous,
    # so no "..." separator is needed between them.
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
    if has_overlap:
        # Overlapping halves: just use the combined unique indices
        indices = sorted(set(first_indices + last_indices))
        separator_index = None
    else:
        # No overlap: add a separator between the first and last blocks
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add the separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")
        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:              # 2D or 3D case
            flat = vec.tolist()
        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= max_vals else flat
        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)
        print(f" [{first_str}, ..., {last_str}]")
    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")

def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn
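
# Usage sketch (comment only): attaching debug_hook via PyTorch forward hooks.
# The model and module path are illustrative; any nn.Module works:
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     for i, layer in enumerate(model.transformer.h):
#         layer.register_forward_hook(debug_hook(f"layer_{i}"))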

def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
    """
    Apply a monkey patch to dump RoPE activations for debugging.

    Args:
        model_module_path: Path to the model module
            (e.g., "transformers.models.apertus.modeling_apertus")
        function_name: Name of the RoPE function to patch
            (default: "apply_rotary_pos_emb")

    Example:
        from utils.common import setup_rope_debug
        setup_rope_debug("transformers.models.apertus.modeling_apertus")
    """
    import importlib

    # Import the module and keep a reference to the original function
    module = importlib.import_module(model_module_path)
    orig_rope = getattr(module, function_name)

    # Set torch print options for better debugging
    torch.set_printoptions(threshold=float("inf"), precision=6, sci_mode=False)

    def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # Log inputs
        summarize(q, "RoPE.q_in")
        summarize(k, "RoPE.k_in")
        # Call the original function
        q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
        # Log outputs
        summarize(q_out, "RoPE.q_out")
        summarize(k_out, "RoPE.k_out")
        return q_out, k_out

    # Patch it
    setattr(module, function_name, debug_rope)
    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")

def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"):
    """
    Save output data (logits/embeddings), tokens, and prompt to files.

    Args:
        data: numpy array of floats (logits or embeddings)
        tokens: list or array of token IDs
        prompt: string containing the input prompt
        model_name: name of the model
        type_suffix: optional suffix like "-embeddings" (default: "")
        output_dir: directory to save files (default: "data")

    Creates the following files in output_dir:
        - pytorch-{model_name}{type_suffix}.bin
        - pytorch-{model_name}{type_suffix}.txt
        - pytorch-{model_name}{type_suffix}-prompt.txt
        - pytorch-{model_name}{type_suffix}-tokens.bin
    """
    data_dir = Path(output_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    base_path = data_dir / f"pytorch-{model_name}{type_suffix}"

    # Convert and flatten logits/embeddings (detach first in case the tensor
    # still requires grad, which would make .numpy() raise)
    data = data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
    data = data.flatten() if data.ndim > 1 else data

    # Save logits/embedding files
    data.astype(np.float32).tofile(f"{base_path}.bin")
    print(f"Data saved to {base_path}.bin")
    with open(f"{base_path}.txt", "w") as f:
        f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data))
    print(f"Data saved to {base_path}.txt")

    # Convert and flatten tokens
    tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens)
    tokens = tokens.flatten() if tokens.ndim > 1 else tokens

    # Save token binary file
    tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin")
    print(f"Tokens saved to {base_path}-tokens.bin")

    # Save prompt file
    with open(f"{base_path}-prompt.txt", "w") as f:
        f.write(f"prompt: {prompt}\n")
        f.write(f"n_tokens: {len(tokens)}\n")
        f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n")
    print(f"Prompt saved to {base_path}-prompt.txt")

def compare_tokens(original, converted, type_suffix="", output_dir="data"):
    data_dir = Path(output_dir)

    # Read tokens from both models
    tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin"
    tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin"
    if not tokens1_file.exists():
        print(f"Error: Token file not found: {tokens1_file}")
        return False
    if not tokens2_file.exists():
        print(f"Error: Token file not found: {tokens2_file}")
        return False

    tokens1 = np.fromfile(tokens1_file, dtype=np.int32)
    tokens2 = np.fromfile(tokens2_file, dtype=np.int32)

    print("\nComparing tokens between:")
    print(f"  Original : {original} ({len(tokens1)} tokens)")
    print(f"  Converted: {converted} ({len(tokens2)} tokens)")

    if len(tokens1) != len(tokens2):
        print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}")
        return False
    if np.array_equal(tokens1, tokens2):
        print(f"\n✅ All {len(tokens1)} tokens match!")
        return True

    mismatches = np.where(tokens1 != tokens2)[0]
    print(f"\n❌ Found {len(mismatches)} mismatched tokens:")
    num_to_show = min(len(mismatches), 10)
    for idx in mismatches[:num_to_show]:
        print(f"  Position {idx}: {tokens1[idx]} vs {tokens2[idx]}")
    if len(mismatches) > num_to_show:
        print(f"  ... and {len(mismatches) - num_to_show} more mismatches")
    return False
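
# Usage sketch (comment only): the base names must match the files on disk,
# including any prefix such as the "pytorch-" prefix written by
# save_output_data. The names below are illustrative:
#
#     compare_tokens("pytorch-llama-3.2", "llamacpp-llama-3.2")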

def show_version_warning(current_version, model_version):
    if not model_version:
        return False
    try:
        from packaging.version import parse, InvalidVersion
        try:
            return parse(current_version) < parse(model_version)
        except InvalidVersion:
            # Version string that packaging cannot parse: fall back to equality
            return current_version != model_version
    except ImportError:
        # packaging is not installed: fall back to a plain string comparison
        return current_version != model_version

def get_model_transformers_version(model_path):
    if not model_path:
        return None
    config_path = Path(model_path) / "config.json"
    if not config_path.is_file():
        return None
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return config.get("transformers_version")
    except (IOError, json.JSONDecodeError) as e:
        print(f"Warning: Could not read or parse {config_path}: {e}", file=sys.stderr)
        return None

def exit_with_warning(message, model_path):
    print(message)
    if model_path and transformers is not None:
        model_transformers_version = get_model_transformers_version(model_path)
        transformers_version = transformers.__version__
        if show_version_warning(transformers_version, model_transformers_version):
            warning_message = f"""
            =====================================================================
            Verification failure might be due to a transformers version mismatch:
              Current transformers version: {transformers_version}
              Model's required version:     {model_transformers_version}
            Consider installing the version specified by the model's config:
              pip install transformers=={model_transformers_version}
            =====================================================================
            """
            print(textwrap.dedent(warning_message))
    sys.exit(1)
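
# Hypothetical end-to-end sketch (comment only) combining the helpers above:
#
#     if not compare_tokens("pytorch-llama-3.2", "llamacpp-llama-3.2"):
#         exit_with_warning("Token verification failed", os.getenv("MODEL_PATH"))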