#!/usr/bin/env python3
import os
import sys

import torch


def get_model_name_from_env_path(env_path_name):
    """Return the model file's base name (without a .gguf suffix) from a path stored in an environment variable."""
    model_path = os.getenv(env_path_name)
    if not model_path:
        print(f"Error: {env_path_name} environment variable not set")
        sys.exit(1)

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)

    name = os.path.basename(os.path.normpath(model_path))
    if name.endswith(".gguf"):
        name = name[:-5]
    return name
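
# Example (illustrative sketch, not part of the module): derive a model name
# from a path stored in an environment variable. The variable name "MODEL_PATH"
# and the path below are assumptions chosen for illustration.
#
#   # With MODEL_PATH=/models/my-model.gguf in the environment:
#   model_name = get_model_name_from_env_path("MODEL_PATH")  # -> "my-model"
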
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
    - 2D tensors (seq, hidden)
    - 3D tensors (batch, seq, hidden)
    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows first and last max_vals of each vector per sequence position.
    """
    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape
    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for first and last sequences
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we're at
    # the edge case of s = 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(set(first_indices + last_indices))
        separator_index = None
    else:
        # If no overlap, we'll add a separator between first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last values
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= max_vals else flat
        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)
        print(f" [{first_str}, ..., {last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")
def debug_hook(name):
    """Create a forward hook that summarizes a module's first input and output tensors."""
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")

        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")

    return fn
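
# Example (illustrative sketch): attach debug_hook to every decoder layer of a
# Hugging Face transformers model via register_forward_hook. The model id and
# the "model.model.layers" attribute path are assumptions; adjust them to the
# architecture being inspected.
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("org/model-id")
#   for i, layer in enumerate(model.model.layers):
#       layer.register_forward_hook(debug_hook(f"layer.{i}"))
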
def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
    """
    Apply a monkey patch that dumps RoPE activations for debugging.

    Args:
        model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
        function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")

    Example:
        from utils.common import setup_rope_debug
        setup_rope_debug("transformers.models.apertus.modeling_apertus")
    """
    import importlib

    # Import the module and get the original function
    module = importlib.import_module(model_module_path)
    orig_rope = getattr(module, function_name)

    # Set torch print options for better debugging
    torch.set_printoptions(threshold=float('inf'))
    torch.set_printoptions(precision=6, sci_mode=False)

    def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # Log inputs
        summarize(q, "RoPE.q_in")
        summarize(k, "RoPE.k_in")

        # Call the original implementation
        q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)

        # Log outputs
        summarize(q_out, "RoPE.q_out")
        summarize(k_out, "RoPE.k_out")
        return q_out, k_out

    # Patch the module-level function
    setattr(module, function_name, debug_rope)
    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
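

# Minimal self-contained demo (an illustrative addition, not part of the
# original utilities): running this file directly dumps a random tensor and the
# input/output of a small linear layer wired up with debug_hook. The shapes are
# arbitrary assumptions chosen for illustration.
if __name__ == "__main__":
    summarize(torch.randn(1, 8, 16), "demo.input")  # (batch, seq, hidden)

    layer = torch.nn.Linear(16, 16)
    layer.register_forward_hook(debug_hook("demo.linear"))
    layer(torch.randn(1, 8, 16))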