run-org-model-multi-token.py

#!/usr/bin/env python3
import argparse
import os
import importlib
from pathlib import Path
import re

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np

### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
#
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===

token_counter = {}


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
    - 2D tensors (seq, hidden)
    - 3D tensors (batch, seq, hidden)
    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows the first and last max_vals values of each vector per sequence position.
    """
    global token_counter

    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape
    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check whether the first and last indices overlap, or whether we hit the edge case of s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(list(set(first_indices + last_indices)))
        separator_index = None
    else:
        # If there's no overlap, add a separator between the first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat

        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        if len(flat) >= 2 * max_vals:
            print(f" [{first_str}, ..., {last_str}]")
        else:
            print(f" [{last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")

    # Per-layer output tensors (names like "model.layers.0_out") are also dumped to disk
    pattern = r"model\.layers\.[0-9]+_out"
    if re.fullmatch(pattern, name):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
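

# Roughly what summarize() prints (values below are placeholders, not real output): one
# header line, then one row of leading/trailing values per shown sequence position, then
# the sum:
#
#   ggml_debug: model.layers.0_out = (f32) ... = {torch.Size([1, 8, 2048])}
#    [
#    [
#    [      0.1234,      -0.5678,       0.9012, ...,      -0.3456,       0.7890,       0.1234]
#    ...
#    ],
#    ]
#    sum = 123.456789
#
# The format mirrors llama.cpp's ggml_debug output so both runs can be compared side by side.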


from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb  # noqa: E402

orig_conv1d_update = torch_causal_conv1d_update
orig_rope = apply_rotary_pos_emb

import torch.nn.functional as F  # noqa: E402
import typing  # noqa: E402


def patched_torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]

    summarize(hidden_states, "hidden_states_in")
    summarize(conv_state, "conv_state_in")

    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    summarize(hidden_states_new, "hidden_states_new")
    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")

    summarize(conv_state, "conv_state_pre")
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    summarize(conv_state, "conv_state_post")

    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    summarize(out, "out")
    summarize(out[:, :, -seq_len:], "out_proper")

    out = F.silu(out[:, :, -seq_len:])
    summarize(out, "out_silu")

    out = out.to(hidden_states.dtype)
    return out


already_dumped_rope = False


def save_tensor(tensor, filename):
    """Save a tensor to a binary file with shape information."""
    # Ensure the tensors directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Convert to numpy and save
    np_array = tensor.detach().cpu().numpy()

    # Save the shape first (4 int64 values), then the data
    with open(filename, 'wb') as f:
        shape = list(np_array.shape)
        while len(shape) < 4:
            shape.insert(0, 0)
        # Write shape as int64
        shape_array = np.array(shape, dtype=np.int64)
        f.write(shape_array.tobytes())
        # Write data as float32
        np_array_float32 = np_array.astype(np.float32)
        f.write(np_array_float32.tobytes())
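

# A minimal sketch (not used by this script, name is illustrative) of how the .bin files
# written by save_tensor() can be read back for comparison: the first 4 int64 values are
# the shape (left-padded with zeros), followed by the float32 data.
def load_tensor(filename):
    with open(filename, 'rb') as f:
        shape = np.frombuffer(f.read(4 * 8), dtype=np.int64)
        shape = tuple(int(d) for d in shape if d != 0)
        data = np.frombuffer(f.read(), dtype=np.float32)
    return data.reshape(shape)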


def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global already_dumped_rope

    # log inputs
    summarize(q, "RoPE.q_in")
    summarize(k, "RoPE.k_in")
    summarize(cos, "cos")
    summarize(sin, "sin")

    if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
        already_dumped_rope = True
        print("Dumping input tensors")
        save_tensor(q, "reference/tensors/testrope_q_in.bin")
        save_tensor(k, "reference/tensors/testrope_k_in.bin")
        save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
        save_tensor(sin, "reference/tensors/testrope_sin_in.bin")

    if position_ids is not None:
        summarize(position_ids, "position_ids")

    print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")

    # call original
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)

    # log outputs
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out


import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402

qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
qwen_mod.apply_rotary_pos_emb = patched_apply_rope

# Store original functions for patching
original_functions = {}


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn
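

# Note: PyTorch calls forward hooks as hook(module, input, output), where `input` is the
# tuple of positional arguments passed to forward(), which is why debug_hook() handles
# both bare tensors and tuples/lists. Hooks can also be detached again if the output gets
# too noisy; the module path below is illustrative and depends on the model:
#
#   handle = model.model.layers[0].register_forward_hook(debug_hook("model.layers.0"))
#   ...                # run a forward pass; inputs/outputs are summarized
#   handle.remove()    # stop logging for that module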


def patch_all_forward_methods(model):
    """Apply monkey patches to all forward methods in the model."""
    for name, module in model.named_modules():
        # Set the layer index if applicable
        parts = name.split('.')
        module.layer_idx = -1  # Default invalid value
        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
            try:
                module.layer_idx = int(parts[2])  # Convert to integer
            except (ValueError, IndexError):
                module.layer_idx = -1

        # Apply a forward hook to log all inputs/outputs
        module.register_forward_hook(debug_hook(name))

        # Additional patches for specific methods in various modules
        if hasattr(module, 'forward'):
            original_forward = module.forward

            def make_patched_forward(orig_forward, mod_name):
                def patched_forward(*args, **kwargs):
                    # Log inputs
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")
                    # Call the original forward
                    result = orig_forward(*args, **kwargs)
                    # Log outputs
                    if isinstance(result, torch.Tensor):
                        summarize(result, f"{mod_name}.forward.out")
                    elif isinstance(result, (tuple, list)):
                        for i, res in enumerate(result):
                            if isinstance(res, torch.Tensor):
                                summarize(res, f"{mod_name}.forward.out_{i}")
                    return result
                return patched_forward

            module.forward = make_patched_forward(original_forward, name)


def patch_silu():
    """Patch torch.nn.functional.silu to log inputs and outputs."""
    global original_functions

    if 'silu' not in original_functions:
        original_functions['silu'] = torch.nn.functional.silu

    def patched_silu(input, inplace=False):
        # Log input
        summarize(input, "silu_in")
        # Call the original function
        result = original_functions['silu'](input, inplace)
        # Log output
        summarize(result, "silu_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.silu = patched_silu


def patch_sigmoid():
    """Patch torch.nn.functional.sigmoid to log inputs and outputs."""
    global original_functions

    if 'sigmoid' not in original_functions:
        original_functions['sigmoid'] = torch.nn.functional.sigmoid

    def patched_sigmoid(input):
        # Log input
        summarize(input, "sigmoid_in")
        # Call the original function
        result = original_functions['sigmoid'](input)
        # Log output
        summarize(result, "sigmoid_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.sigmoid = patched_sigmoid


def patch_torch_sigmoid():
    """Patch torch.sigmoid to log inputs and outputs."""
    global original_functions

    if 'torch_sigmoid' not in original_functions:
        original_functions['torch_sigmoid'] = torch.sigmoid

    def patched_torch_sigmoid(input):
        # Log input
        summarize(input, "torch_sigmoid_in")
        # Call the original function
        result = original_functions['torch_sigmoid'](input)
        # Log output
        summarize(result, "torch_sigmoid_out")
        return result

    # Replace torch.sigmoid with the logging wrapper
    torch.sigmoid = patched_torch_sigmoid


def patch_pad():
    """Patch torch.nn.functional.pad to log inputs and outputs."""
    global original_functions

    if 'pad' not in original_functions:
        original_functions['pad'] = torch.nn.functional.pad

    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None):  # pyright: ignore[reportGeneralTypeIssues]
        # Log input
        summarize(input, "pad_in")
        print(f"Padding shape is {pad}")
        # Call the original function
        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
        # Log output
        summarize(result, "pad_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.pad = patched_pad


def save_kv_cache(past_key_values, step_num, data_dir, model_name):
    """Save the KV cache tensors for each layer."""
    cache_dir = data_dir / f"kv_cache_step_{step_num}"
    cache_dir.mkdir(exist_ok=True)

    # Access past_key_values if available
    if past_key_values is not None:
        for layer_idx, cache_tuple in enumerate(past_key_values):
            if cache_tuple is None:
                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                continue

            # Handle different cache formats
            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                key, value = cache_tuple[0], cache_tuple[1]
                # Check that key and value are not None
                if key is not None and value is not None:
                    # Save the key cache
                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
                    # Save the value cache
                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                else:
                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
            else:
                # Handle other cache formats (e.g., recurrent models)
                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
                # Save as a generic cache if it's a tensor
                if hasattr(cache_tuple, 'detach'):
                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
    else:
        print(f"No KV cache available at step {step_num}")
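

# Note: unlike save_tensor(), the KV cache files written above are raw float32 with no
# shape header; the shapes are only printed to stdout at save time. A minimal sketch of
# loading one back (illustrative path, shape taken from the printed log, layout assumed
# to be the usual HF cache layout):
#
#   key = np.fromfile("data/kv_cache_step_0/layer_0_key.bin", dtype=np.float32)
#   key = key.reshape(1, num_kv_heads, seq_len, head_dim)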


unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )
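
# Example invocations (paths are placeholders):
#
#   python run-org-model-multi-token.py --model-path /path/to/hf-model -n 5 -p "Hello, my name is"
#   MODEL_PATH=/path/to/hf-model python run-org-model-multi-token.py --save-cache
#
# Note that the MODEL_PATH environment variable, when set, takes precedence over --model-path.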

config = AutoConfig.from_pretrained(model_path)

print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")
    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

patch_all_forward_methods(model)
patch_silu()
patch_pad()
patch_sigmoid()
patch_torch_sigmoid()

model_name = os.path.basename(model_path)

# Print the model class to allow for easier debugging. This can be useful when
# working with models that have not been publicly released yet, which might
# require that the concrete class is imported and used directly instead of
# using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Store all generated tokens and logits
all_generated_tokens = []
all_logits = []

with torch.no_grad():
    # Initial forward pass
    print("\n=== Initial Forward Pass ===")
    outputs = model(input_ids, use_cache=True)
    logits = outputs.logits

    # Extract logits for the last token (next-token prediction)
    last_logits = logits[0, -1, :].cpu().numpy()
    all_logits.append(last_logits)
    print(f"Logits shape: {logits.shape}")
    print(f"Last token logits shape: {last_logits.shape}")

    # Generate the first token (greedy)
    next_token_id = np.argmax(last_logits).item()
    all_generated_tokens.append(next_token_id)

    # Show the top 5 predicted tokens for the first step
    top_indices = np.argsort(last_logits)[-5:][::-1]
    print("Top 5 predictions for first token:")
    for idx in top_indices:
        token = tokenizer.decode([idx])
        print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

    # Save KV cache if requested
    if args.save_cache:
        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)

    # Prepare for the next iteration
    past_key_values = outputs.past_key_values
    current_input = torch.tensor([[next_token_id]], device=device)

    # Generate the remaining tokens
    for step in range(1, args.num_tokens):
        print(f"\n=== Generation Step {step} ===")

        # Forward pass with cache
        outputs = model(
            input_ids=current_input,
            past_key_values=past_key_values,
            use_cache=True
        )
        logits = outputs.logits
        last_logits = logits[0, -1, :].cpu().numpy()
        all_logits.append(last_logits)

        # Generate the next token (greedy)
        next_token_id = np.argmax(last_logits).item()
        all_generated_tokens.append(next_token_id)

        # Show the top 5 predicted tokens for this step
        top_indices = np.argsort(last_logits)[-5:][::-1]
        print(f"Top 5 predictions for step {step}:")
        for idx in top_indices:
            token = tokenizer.decode([idx])
            print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

        # Save KV cache if requested
        if args.save_cache:
            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)

        # Update for the next iteration
        past_key_values = outputs.past_key_values
        current_input = torch.tensor([[next_token_id]], device=device)

# Save results
bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"

# Save all logits concatenated
all_logits_array = np.array(all_logits)
all_logits_array.astype(np.float32).tofile(bin_filename)

# Also save as a text file for easy inspection
with open(txt_filename, "w") as f:
    f.write(f"Generated tokens: {all_generated_tokens}\n")
    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
    for step, logits in enumerate(all_logits):
        f.write(f"=== Step {step} logits ===\n")
        for i, logit in enumerate(logits):
            f.write(f"{i}: {logit:.6f}\n")
        f.write("\n")

print("\n=== Generation Complete ===")
print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")

if args.save_cache:
    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
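
# A minimal sketch (not executed by this script) of reading the saved logits back for
# comparison against another run; the .bin file is raw float32 with one row per
# generated token, so it can be reshaped to (num_steps, vocab_size):
#
#   logits = np.fromfile(bin_filename, dtype=np.float32).reshape(len(all_logits), -1)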