run-org-model-multi-token.py

#!/usr/bin/env python3

import argparse
import os
import importlib
from pathlib import Path
import re

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np

### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
#
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===

token_counter = {}


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
    - 2D tensors (seq, hidden)
    - 3D tensors (batch, seq, hidden)
    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head

    Shows the first and last max_vals of each vector per sequence position.
    """
    global token_counter

    t = tensor.detach().to(torch.float32).cpu()

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    ten_shape = t.shape

    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we hit
    # the edge case of s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(list(set(first_indices + last_indices)))
        separator_index = None
    else:
        # If no overlap, we'll add a separator between first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat
        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        if len(flat) >= 2 * max_vals:
            print(f" [{first_str}, ..., {last_str}]")
        else:
            print(f" [{last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")

    pattern = r"model\.layers\.[0-9]+_out"
    pattern2 = r"recurrent_cache_[0-9]+"
    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
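# Note: tensors whose names match "model.layers.<n>_out" or "recurrent_cache_<n>"
# are additionally written by save_tensor() (defined below) into
# reference/tensors/org/, with a per-name counter suffix (_1.bin, _2.bin, ...)
# so that repeated dumps of the same tensor are not overwritten.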
from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402

orig_conv1d_update = torch_causal_conv1d_update
orig_rope = apply_rotary_pos_emb

import torch.nn.functional as F  # noqa: E402
import typing  # noqa: E402


def patched_torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]

    summarize(hidden_states, "hidden_states_in")
    summarize(conv_state, "conv_state_in")

    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    summarize(hidden_states_new, "hidden_states_new")
    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")
    summarize(conv_state, "conv_state_pre")
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    summarize(conv_state, "conv_state_post")

    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    summarize(out, "out")
    summarize(out[:, :, -seq_len:], "out_proper")
    out = F.silu(out[:, :, -seq_len:])
    summarize(out, "out_silu")
    out = out.to(hidden_states.dtype)
    return out
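# Shape convention assumed by the patched conv update above (inferred from the
# code, not from upstream documentation):
#   hidden_states: (batch, hidden_size, seq_len)   -- activations for the new token(s)
#   conv_state:    (batch, hidden_size, state_len) -- rolling causal-conv window
# A minimal illustrative call with made-up sizes:
#   hs = torch.randn(1, 8, 1)       # one new token, 8 channels
#   state = torch.zeros(1, 8, 4)    # conv window of length 4
#   w = torch.randn(8, 4)           # one depthwise filter row per channel
#   y = patched_torch_causal_conv1d_update(hs, state, w)  # y: (1, 8, 1)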
already_dumped_rope = False


def save_tensor(tensor, filename):
    """Save tensor to binary file with shape information."""
    # Ensure tensors directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Convert to numpy and save
    np_array = tensor.detach().cpu().numpy()

    # Save shape first (4 int64 values), then data
    with open(filename, 'wb') as f:
        shape = list(np_array.shape)
        while len(shape) < 4:
            shape.insert(0, 0)
        # Write shape as int64
        shape_array = np.array(shape, dtype=np.int64)
        f.write(shape_array.tobytes())
        # Write data as float32
        np_array_float32 = np_array.astype(np.float32)
        f.write(np_array_float32.tobytes())
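# A dump written by save_tensor() can be read back for comparison along these
# lines (sketch based on the 4-slot int64 shape header written above; the path
# is just an example of what summarize() produces):
#   with open("reference/tensors/org/model.layers.0_out_1.bin", "rb") as fh:
#       dims = np.frombuffer(fh.read(4 * 8), dtype=np.int64)
#       data = np.frombuffer(fh.read(), dtype=np.float32)
#       data = data.reshape([int(d) for d in dims if d > 0])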
def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global already_dumped_rope
    # log inputs
    summarize(q, "RoPE.q_in")
    summarize(k, "RoPE.k_in")
    summarize(cos, "cos")
    summarize(sin, "sin")
    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
    #     already_dumped_rope = True
    #     print("Dumping input tensors")
    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
    if position_ids is not None:
        summarize(position_ids, "position_ids")
    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
    # call original
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
    # log outputs
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out


def patched_torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
    long=False
):
    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
    initial_dtype = query.dtype
    [ summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm")) ]
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    [ summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig")) ]

    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    [ summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra")) ]

    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")

    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    [ summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad")) ]

    tot_heads = num_heads + pad_size

    scale = 1 / (query.shape[-1] ** 0.5)
    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
    query = query * scale
    summarize(query, "q_scaled")
    summarize(key, "k")
    summarize(beta.unsqueeze(-1), "beta")

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    summarize(k_beta, "k_beta")
    summarize(v_beta, "v_beta")

    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    [ summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh")) ]

    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    summarize(g, "g_cumsum")
    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
    summarize(bt1, "bt1")
    summarize(bt2, "bt2")
    summarize(sub, "sub")
    decay_mask = sub.tril()
    summarize(decay_mask, "sub_tril")
    decay_mask = decay_mask.exp()
    summarize(decay_mask, "sub_tril_exp")
    decay_mask = decay_mask.float()
    summarize(decay_mask, "sub_tril_exp_float")
    decay_mask = decay_mask.tril()
    summarize(decay_mask, "decay_mask")

    k_t = key.transpose(-1, -2)
    summarize(k_t, "k_t")
    kmul = k_beta @ k_t
    summarize(kmul, "k_beta @ k_t")
    # if not long:
    #     print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
    kmul_decay = kmul * decay_mask
    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
    attn = -(kmul_decay).masked_fill(mask, 0)
    summarize(attn, "attn_in")

    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
        # if i <= num_heads and not long:
        #     print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
        #     print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
    summarize(attn, "attn_chunks")

    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    summarize(attn, "attn_eye")

    value = attn @ v_beta
    summarize(value, "value")
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    summarize(k_cumdecay, "k_cumdecay")

    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        print(f"\n=== Processing chunk {i} ===")
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        summarize(q_i, f"q_i_chunk_{i}")
        summarize(k_i, f"k_i_chunk_{i}")
        summarize(v_i, f"v_i_chunk_{i}")
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        summarize(attn, f"attn_chunk_{i}")
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        summarize(v_prime, f"v_prime_chunk_{i}")
        v_new = v_i - v_prime
        summarize(v_new, f"v_new_chunk_{i}")
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        summarize(attn_inter, f"attn_inter_chunk_{i}")
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
        g_last = g[:, :, i, -1, None, None].exp()
        summarize(g_last, f"g_last_chunk_{i}")
        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
        last_recurrent_state = (
            last_recurrent_state * g_last
            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
        )
        summarize(last_recurrent_state, f"updated_state_chunk_{i}")

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    summarize(core_attn_out, "attn_out")
    if not long:
        print(f"attn_out:\n{core_attn_out}\n\n")
    if isinstance(last_recurrent_state, torch.Tensor):
        summarize(last_recurrent_state, "state_out")
        if not long:
            print(f"state_out:\n{last_recurrent_state}\n\n")
    return core_attn_out, last_recurrent_state
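# Reading aid for the chunked pass above (S = last_recurrent_state, per chunk i):
#   v_new = v_i - k_cumdecay_i @ S                      # subtract what S already explains
#   out_i = (q_i * exp(g_i)) @ S + attn_i @ v_new       # inter-chunk + intra-chunk terms
#   S     = S * exp(g_last) + (k_i * exp(g_last - g_i))^T @ v_new
# This only restates the code for easier inspection of the dumps; the
# authoritative version is torch_chunk_gated_delta_rule in
# transformers.models.qwen3_next.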
def patched_torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    summarize(query, "q_t")
    summarize(key, "k_t")
    summarize(value, "v_t")
    summarize(beta, "beta_t")
    summarize(g, "g_t")

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    summarize(query, "q_scaled")

    if initial_state is not None:
        summarize(initial_state, "initial_state")

    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        summarize(g_t, "g_exp_unsq")
        beta_t = beta[:, :, i].unsqueeze(-1)
        summarize(beta_t, "beta_t_unsq")
        last_recurrent_state = last_recurrent_state * g_t
        summarize(last_recurrent_state, "gated_state")
        k_unsq = k_t.unsqueeze(-1)
        summarize(k_unsq, "k_unsqueeze")
        state_k = last_recurrent_state * k_unsq
        summarize(state_k, "state_k_product")
        kv_mem = state_k.sum(dim=-2)
        summarize(kv_mem, "kv_mem")
        delta = (v_t - kv_mem) * beta_t
        summarize(delta, "delta")
        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        summarize(k_delta, "k_delta")
        last_recurrent_state = last_recurrent_state + k_delta
        summarize(last_recurrent_state, "state_plus_k_delta")
        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
        summarize(state_q_prod, "state_q_product")
        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
        summarize(core_attn_out, "core_attn_out")

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
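# Reading aid for the per-token pass above (S = last_recurrent_state):
#   S   <- S * exp(g_t)
#   d_t <- beta_t * (v_t - k_t^T S)        # "delta" in the code
#   S   <- S + k_t d_t^T                   # outer-product update along the key dim
#   o_t <- q_t^T S
# Again this only restates the loop above to make the dumped names easier to
# compare against other implementations.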
import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402

qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
qwen_mod.apply_rotary_pos_emb = patched_apply_rope
qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
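# The four assignments above swap the Qwen3-Next reference kernels for the
# instrumented copies defined in this file. For another architecture the same
# pattern applies: import its modeling module and overwrite the functions of
# interest, e.g. (hypothetical module name, mirroring the Apertus example at
# the top of this file):
#   import transformers.models.mymodel.modeling_mymodel as my_mod  # noqa: E402
#   my_mod.apply_rotary_pos_emb = patched_apply_rope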
# Store original functions for patching
original_functions = {}


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn


def patch_all_forward_methods(model):
    """Apply monkey patches to all forward methods in the model"""
    for name, module in model.named_modules():
        # Set layer index if applicable
        parts = name.split('.')
        module.layer_idx = -1  # Default invalid value
        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
            try:
                module.layer_idx = int(parts[2])  # Convert to integer
            except (ValueError, IndexError):
                module.layer_idx = -1

        # Apply forward hook to log all inputs/outputs
        module.register_forward_hook(debug_hook(name))

        # Additional patches for specific methods in various modules
        if hasattr(module, 'forward'):
            original_forward = module.forward

            def make_patched_forward(orig_forward, mod_name):
                def patched_forward(*args, **kwargs):
                    # Log inputs
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")
                    # Call original forward
                    result = orig_forward(*args, **kwargs)
                    if mod_name.endswith("linear_attn"):
                        cache = kwargs["cache_params"]
                        nameparts = mod_name.split(".")
                        layer_idx = -1
                        try:
                            layer_idx = int(nameparts[2])
                        except (ValueError, IndexError):
                            print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
                        rec_cache = cache.recurrent_states[layer_idx]
                        if rec_cache is not None:
                            summarize(rec_cache, f"recurrent_cache_{layer_idx}")
                    # Log output
                    if isinstance(result, torch.Tensor):
                        summarize(result, f"{mod_name}.forward.out")
                    elif isinstance(result, (tuple, list)):
                        for i, res in enumerate(result):
                            if isinstance(res, torch.Tensor):
                                summarize(res, f"{mod_name}.forward.out_{i}")
                    return result
                return patched_forward

            module.forward = make_patched_forward(original_forward, name)
def patch_silu():
    """Patch torch.nn.functional.silu to log inputs and outputs"""
    global original_functions
    if 'silu' not in original_functions:
        original_functions['silu'] = torch.nn.functional.silu

    def patched_silu(input, inplace=False):
        # Log input
        summarize(input, "silu_in")
        # Call original function
        result = original_functions['silu'](input, inplace)
        # Log output
        summarize(result, "silu_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.silu = patched_silu
def patch_sigmoid():
    """Patch torch.nn.functional.sigmoid to log inputs and outputs"""
    global original_functions
    if 'sigmoid' not in original_functions:
        original_functions['sigmoid'] = torch.nn.functional.sigmoid

    def patched_sigmoid(input):
        # Log input
        summarize(input, "sigmoid_in")
        # Call original function
        result = original_functions['sigmoid'](input)
        # Log output
        summarize(result, "sigmoid_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.sigmoid = patched_sigmoid
def patch_torch_sigmoid():
    """Patch torch.sigmoid to log inputs and outputs"""
    global original_functions
    if 'torch_sigmoid' not in original_functions:
        original_functions['torch_sigmoid'] = torch.sigmoid

    def patched_torch_sigmoid(input):
        # Log input
        summarize(input, "torch_sigmoid_in")
        # Call original function
        result = original_functions['torch_sigmoid'](input)
        # Log output
        summarize(result, "torch_sigmoid_out")
        return result

    # Replace the function in the torch module
    torch.sigmoid = patched_torch_sigmoid
def patch_pad():
    """Patch torch.nn.functional.pad to log inputs and outputs"""
    global original_functions
    if 'pad' not in original_functions:
        original_functions['pad'] = torch.nn.functional.pad

    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None):  # pyright: ignore[reportGeneralTypeIssues]
        # Log input
        summarize(input, "pad_in")
        print(f"Padding shape is {pad}")
        # Call original function
        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
        # Log output
        summarize(result, "pad_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.pad = patched_pad
def save_kv_cache(past_key_values, step_num, data_dir, model_name):
    """Save KV cache tensors for each layer"""
    cache_dir = data_dir / f"kv_cache_step_{step_num}"
    cache_dir.mkdir(exist_ok=True)

    # Access past_key_values if available
    if past_key_values is not None:
        for layer_idx, cache_tuple in enumerate(past_key_values):
            if cache_tuple is None:
                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                continue

            # Handle different cache formats
            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                key, value = cache_tuple[0], cache_tuple[1]
                # Check if key and value are not None
                if key is not None and value is not None:
                    # Save key cache
                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
                    # Save value cache
                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                else:
                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
            else:
                # Handle other cache formats (e.g., recurrent models)
                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
                # Save as generic cache if it's a tensor
                if hasattr(cache_tuple, 'detach'):
                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
    else:
        print(f"No KV cache available at step {step_num}")
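# Unlike save_tensor(), the per-layer KV dumps above are raw float32 with no
# shape header, so the shapes printed to stdout are needed to reinterpret them,
# e.g. (sketch; path and shape values are illustrative):
#   k = np.fromfile("data/kv_cache_step_0/layer_0_key.bin", dtype=np.float32)
#   k = k.reshape(1, num_kv_heads, seq_len, head_dim)  # take the shape from the log line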
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
args = parser.parse_args()
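# Typical invocations (paths are placeholders):
#   python run-org-model-multi-token.py --model-path /path/to/model -n 5 -p "Hello, my name is"
#   MODEL_PATH=/path/to/model python run-org-model-multi-token.py --save-cache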
model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )

config = AutoConfig.from_pretrained(model_path)
print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")
    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

patch_all_forward_methods(model)
patch_silu()
patch_pad()
patch_sigmoid()
patch_torch_sigmoid()

model_name = os.path.basename(model_path)
# Print the model class to make debugging easier. This can be useful when
# working with models that have not been publicly released yet, where the
# concrete class might need to be imported and used directly instead of
# AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Store all generated tokens and logits
all_generated_tokens = []
all_logits = []

with torch.no_grad():
    # Initial forward pass
    print(f"\n=== Initial Forward Pass ===")
    outputs = model(input_ids, use_cache=True)
    logits = outputs.logits

    # Extract logits for the last token (next token prediction)
    last_logits = logits[0, -1, :].cpu().numpy()
    all_logits.append(last_logits)
    print(f"Logits shape: {logits.shape}")
    print(f"Last token logits shape: {last_logits.shape}")

    # Generate first token
    next_token_id = np.argmax(last_logits).item()
    all_generated_tokens.append(next_token_id)

    # Show top 5 predicted tokens for first step
    top_indices = np.argsort(last_logits)[-5:][::-1]
    print("Top 5 predictions for first token:")
    for idx in top_indices:
        token = tokenizer.decode([idx])
        print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

    # Save KV cache if requested
    if args.save_cache:
        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)

    # Prepare for next iteration
    past_key_values = outputs.past_key_values
    current_input = torch.tensor([[next_token_id]], device=device)

    # Generate remaining tokens
    for step in range(1, args.num_tokens):
        print(f"\n=== Generation Step {step} ===")

        # Forward pass with cache
        outputs = model(
            input_ids=current_input,
            past_key_values=past_key_values,
            use_cache=True
        )
        logits = outputs.logits
        last_logits = logits[0, -1, :].cpu().numpy()
        all_logits.append(last_logits)

        # Generate next token
        next_token_id = np.argmax(last_logits).item()
        all_generated_tokens.append(next_token_id)

        # Show top 5 predicted tokens for this step
        top_indices = np.argsort(last_logits)[-5:][::-1]
        print(f"Top 5 predictions for step {step}:")
        for idx in top_indices:
            token = tokenizer.decode([idx])
            print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

        # Save KV cache if requested
        if args.save_cache:
            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)

        # Update for next iteration
        past_key_values = outputs.past_key_values
        current_input = torch.tensor([[next_token_id]], device=device)

# Save results
bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"

# Save all logits concatenated
all_logits_array = np.array(all_logits)
all_logits_array.astype(np.float32).tofile(bin_filename)
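# The .bin file holds one float32 logit row per generation step; it can be
# reloaded for comparison roughly like this:
#   logits = np.fromfile(bin_filename, dtype=np.float32).reshape(len(all_logits), -1)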
# Also save as text file for easy inspection
with open(txt_filename, "w") as f:
    f.write(f"Generated tokens: {all_generated_tokens}\n")
    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
    for step, logits in enumerate(all_logits):
        f.write(f"=== Step {step} logits ===\n")
        for i, logit in enumerate(logits):
            f.write(f"{i}: {logit:.6f}\n")
        f.write("\n")

print(f"\n=== Generation Complete ===")
print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")

if args.save_cache:
    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")