#!/usr/bin/env python3
import argparse
import os
import importlib
from pathlib import Path
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np

### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
#
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===

token_counter = {}


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
    - 2D tensors (seq, hidden)
    - 3D tensors (batch, seq, hidden)
    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
    - 5D tensors (leading singleton dimensions are squeezed away)
    Shows the first and last max_vals of each vector per sequence position.
    """
    global token_counter
    t = tensor.detach().to(torch.float32).cpu()
    ten_shape = t.shape

    while t.ndim > 4:
        t = t.squeeze(0)

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we hit
    # the edge case where s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(list(set(first_indices + last_indices)))
        separator_index = None
    else:
        # If there's no overlap, add a separator between first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat

        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        if len(flat) >= 2 * max_vals:
            print(f" [{first_str}, ..., {last_str}]")
        else:
            print(f" [{last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")

    pattern = r"model\.layers\.[0-9]+_out"
    pattern2 = r"recurrent_cache_[0-9]+"
    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
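
# For illustration, a call like summarize(q, "RoPE.q_in") prints a header such
# as "ggml_debug: RoPE.q_in = (f32) ... = {torch.Size([1, 16, 5, 128])}"
# (the shape here is made up), followed by the first/last few values of the
# first and last sequence positions and the tensor sum, mirroring the format
# of llama.cpp's debug dumps.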

from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402

orig_conv1d_update = torch_causal_conv1d_update
orig_rope = apply_rotary_pos_emb

import torch.nn.functional as F  # noqa: E402
import typing  # noqa: E402


def patched_torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]

    summarize(hidden_states, "hidden_states_in")
    summarize(conv_state, "conv_state_in")

    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    summarize(hidden_states_new, "hidden_states_new")
    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")
    summarize(conv_state, "conv_state_pre")
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    summarize(conv_state, "conv_state_post")

    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    summarize(out, "out")
    summarize(out[:, :, -seq_len:], "out_proper")
    out = F.silu(out[:, :, -seq_len:])
    summarize(out, "out_silu")
    out = out.to(hidden_states.dtype)
    return out

already_dumped_rope = False


def save_tensor(tensor, filename):
    """Save tensor to a binary file with shape information."""
    # Ensure the tensors directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Convert to numpy and save
    np_array = tensor.detach().cpu().numpy()

    # Save the shape first (4 int64 values), then the data
    with open(filename, 'wb') as f:
        shape = list(np_array.shape)
        while len(shape) < 4:
            shape.insert(0, 0)
        # Write the shape as int64
        shape_array = np.array(shape, dtype=np.int64)
        f.write(shape_array.tobytes())
        # Write the data as float32
        np_array_float32 = np_array.astype(np.float32)
        f.write(np_array_float32.tobytes())
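
# For reference, a minimal sketch of how a dump written by save_tensor() could
# be read back. The helper name load_saved_tensor is illustrative only and is
# not used by this script; it assumes the layout produced above: 4 int64 values
# (the shape, left-padded with zeros) followed by raw float32 data.
def load_saved_tensor(filename):
    with open(filename, 'rb') as f:
        # First 4 int64 values hold the shape, left-padded with zeros
        shape = np.frombuffer(f.read(4 * 8), dtype=np.int64)
        shape = tuple(int(d) for d in shape if d != 0)
        # The rest of the file is the tensor data as raw float32
        data = np.frombuffer(f.read(), dtype=np.float32)
    return data.reshape(shape)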


def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global already_dumped_rope
    # log inputs
    summarize(q, "RoPE.q_in")
    summarize(k, "RoPE.k_in")
    summarize(cos, "cos")
    summarize(sin, "sin")

    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
    #     already_dumped_rope = True
    #     print("Dumping input tensors")
    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")

    if position_ids is not None:
        summarize(position_ids, "position_ids")
    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")

    # call original
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)

    # log outputs
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out


def patched_torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
    long=False,
):
    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
    initial_dtype = query.dtype

    [summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm"))]
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    [summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig"))]

    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    [summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra"))]

    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")

    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    [summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad"))]

    tot_heads = num_heads + pad_size

    scale = 1 / (query.shape[-1] ** 0.5)
    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
    query = query * scale
    summarize(query, "q_scaled")
    summarize(key, "k")
    summarize(beta.unsqueeze(-1), "beta")

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    summarize(k_beta, "k_beta")
    summarize(v_beta, "v_beta")

    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    [summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh"))]

    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    summarize(g, "g_cumsum")
    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
    summarize(bt1, "bt1")
    summarize(bt2, "bt2")
    summarize(sub, "sub")
    decay_mask = sub.tril()
    summarize(decay_mask, "sub_tril")
    decay_mask = decay_mask.exp()
    summarize(decay_mask, "sub_tril_exp")
    decay_mask = decay_mask.float()
    summarize(decay_mask, "sub_tril_exp_float")
    decay_mask = decay_mask.tril()
    summarize(decay_mask, "decay_mask")

    k_t = key.transpose(-1, -2)
    summarize(k_t, "k_t")
    kmul = k_beta @ k_t
    summarize(kmul, "k_beta @ k_t")
    # if not long:
    #     print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
    kmul_decay = kmul * decay_mask
    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
    attn = -(kmul_decay).masked_fill(mask, 0)
    summarize(attn, "attn_in")

    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
        # if i <= num_heads and not long:
        #     print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
        #     print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
    summarize(attn, "attn_chunks")

    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    summarize(attn, "attn_eye")

    value = attn @ v_beta
    summarize(value, "value")
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    summarize(k_cumdecay, "k_cumdecay")

    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        print(f"\n=== Processing chunk {i} ===")
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        summarize(q_i, f"q_i_chunk_{i}")
        summarize(k_i, f"k_i_chunk_{i}")
        summarize(v_i, f"v_i_chunk_{i}")
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        summarize(attn, f"attn_chunk_{i}")
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        summarize(v_prime, f"v_prime_chunk_{i}")
        v_new = v_i - v_prime
        summarize(v_new, f"v_new_chunk_{i}")
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        summarize(attn_inter, f"attn_inter_chunk_{i}")
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
        g_last = g[:, :, i, -1, None, None].exp()
        summarize(g_last, f"g_last_chunk_{i}")
        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
        last_recurrent_state = (
            last_recurrent_state * g_last
            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
        )
        summarize(last_recurrent_state, f"updated_state_chunk_{i}")

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    summarize(core_attn_out, "attn_out")
    if not long:
        print(f"attn_out:\n{core_attn_out}\n\n")
    if isinstance(last_recurrent_state, torch.Tensor):
        summarize(last_recurrent_state, "state_out")
        if not long:
            print(f"state_out:\n{last_recurrent_state}\n\n")
    return core_attn_out, last_recurrent_state


def patched_torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    summarize(query, "q_t")
    summarize(key, "k_t")
    summarize(value, "v_t")
    summarize(beta, "beta_t")
    summarize(g, "g_t")

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    summarize(query, "q_scaled")

    if initial_state is not None:
        summarize(initial_state, "initial_state")

    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        summarize(g_t, "g_exp_unsq")
        beta_t = beta[:, :, i].unsqueeze(-1)
        summarize(beta_t, "beta_t_unsq")

        last_recurrent_state = last_recurrent_state * g_t
        summarize(last_recurrent_state, "gated_state")
        k_unsq = k_t.unsqueeze(-1)
        summarize(k_unsq, "k_unsqueeze")
        state_k = last_recurrent_state * k_unsq
        summarize(state_k, "state_k_product")
        kv_mem = state_k.sum(dim=-2)
        summarize(kv_mem, "kv_mem")
        delta = (v_t - kv_mem) * beta_t
        summarize(delta, "delta")
        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        summarize(k_delta, "k_delta")
        last_recurrent_state = last_recurrent_state + k_delta
        summarize(last_recurrent_state, "state_plus_k_delta")
        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
        summarize(state_q_prod, "state_q_product")
        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
    summarize(core_attn_out, "core_attn_out")

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
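
# The per-token loop above implements the gated delta rule recurrence that the
# chunked version computes block-wise. Informally, for each head and time
# step t (this is a reading of the code above, not a statement from the
# upstream implementation's documentation):
#
#     S_t = exp(g_t) * S_{t-1}                 (decay the state)
#     delta_t = beta_t * (v_t - S_t^T k_t)     (error against current memory)
#     S_t = S_t + k_t delta_t^T                (rank-1 state update)
#     o_t = S_t^T q_t                          (read out with the query)
#
# where S_t is the per-head (k_head_dim x v_head_dim) recurrent state and q_t
# already includes the 1/sqrt(head_dim) scaling applied above.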


import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402

qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
qwen_mod.apply_rotary_pos_emb = patched_apply_rope
qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule

# Store original functions for patching
original_functions = {}


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn


def patch_all_forward_methods(model):
    """Apply monkey patches to all forward methods in the model."""
    for name, module in model.named_modules():
        # Set layer index if applicable
        parts = name.split('.')
        module.layer_idx = -1  # Default invalid value
        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
            try:
                module.layer_idx = int(parts[2])  # Convert to integer
            except (ValueError, IndexError):
                module.layer_idx = -1

        # Apply forward hook to log all inputs/outputs
        module.register_forward_hook(debug_hook(name))

        # Additional patches for specific methods in various modules
        if hasattr(module, 'forward'):
            original_forward = module.forward

            def make_patched_forward(orig_forward, mod_name):
                def patched_forward(*args, **kwargs):
                    # Log inputs
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")

                    # Call original forward
                    result = orig_forward(*args, **kwargs)

                    if mod_name.endswith("linear_attn"):
                        cache = kwargs["cache_params"]
                        nameparts = mod_name.split(".")
                        layer_idx = -1
                        try:
                            layer_idx = int(nameparts[2])
                        except (ValueError, IndexError):
                            print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
                        rec_cache = cache.recurrent_states[layer_idx]
                        if rec_cache is not None:
                            summarize(rec_cache, f"recurrent_cache_{layer_idx}")

                    # Log output
                    if isinstance(result, torch.Tensor):
                        summarize(result, f"{mod_name}.forward.out")
                    elif isinstance(result, (tuple, list)):
                        for i, res in enumerate(result):
                            if isinstance(res, torch.Tensor):
                                summarize(res, f"{mod_name}.forward.out_{i}")
                    return result
                return patched_forward

            module.forward = make_patched_forward(original_forward, name)


def patch_silu():
    """Patch torch.nn.functional.silu to log inputs and outputs."""
    global original_functions
    if 'silu' not in original_functions:
        original_functions['silu'] = torch.nn.functional.silu

    def patched_silu(input, inplace=False):
        # Log input
        summarize(input, "silu_in")
        # Call original function
        result = original_functions['silu'](input, inplace)
        # Log output
        summarize(result, "silu_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.silu = patched_silu


def patch_sigmoid():
    """Patch torch.nn.functional.sigmoid to log inputs and outputs."""
    global original_functions
    if 'sigmoid' not in original_functions:
        original_functions['sigmoid'] = torch.nn.functional.sigmoid

    def patched_sigmoid(input):
        # Log input
        summarize(input, "sigmoid_in")
        # Call original function
        result = original_functions['sigmoid'](input)
        # Log output
        summarize(result, "sigmoid_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.sigmoid = patched_sigmoid


def patch_torch_sigmoid():
    """Patch torch.sigmoid to log inputs and outputs."""
    global original_functions
    if 'torch_sigmoid' not in original_functions:
        original_functions['torch_sigmoid'] = torch.sigmoid

    def patched_torch_sigmoid(input):
        # Log input
        summarize(input, "torch_sigmoid_in")
        # Call original function
        result = original_functions['torch_sigmoid'](input)
        # Log output
        summarize(result, "torch_sigmoid_out")
        return result

    # Replace the function in the torch module
    torch.sigmoid = patched_torch_sigmoid


def patch_pad():
    """Patch torch.nn.functional.pad to log inputs and outputs."""
    global original_functions
    if 'pad' not in original_functions:
        original_functions['pad'] = torch.nn.functional.pad

    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None):  # pyright: ignore[reportGeneralTypeIssues]
        # Log input
        summarize(input, "pad_in")
        print(f"Padding shape is {pad}")
        # Call original function
        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
        # Log output
        summarize(result, "pad_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.pad = patched_pad


def save_kv_cache(past_key_values, step_num, data_dir, model_name):
    """Save KV cache tensors for each layer."""
    cache_dir = data_dir / f"kv_cache_step_{step_num}"
    cache_dir.mkdir(exist_ok=True)

    # Access past_key_values if available
    if past_key_values is not None:
        for layer_idx, cache_tuple in enumerate(past_key_values):
            if cache_tuple is None:
                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                continue

            # Handle different cache formats
            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                key, value = cache_tuple[0], cache_tuple[1]
                # Check that key and value are not None
                if key is not None and value is not None:
                    # Save key cache
                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)

                    # Save value cache
                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)

                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                else:
                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
            else:
                # Handle other cache formats (e.g., recurrent models)
                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
                # Save as a generic cache if it's a tensor
                if hasattr(cache_tuple, 'detach'):
                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
    else:
        print(f"No KV cache available at step {step_num}")
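
# Note: unlike save_tensor() above, the KV-cache dumps written by
# save_kv_cache() are raw float32 with no shape header, so the shapes printed
# to stdout at save time are needed to reinterpret the files (e.g. with
# np.fromfile(path, dtype=np.float32).reshape(...)).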


unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )
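
# Example invocations (the checkpoint path is just a placeholder):
#   MODEL_PATH=/path/to/Qwen3-Next-checkpoint python run-org-model-multi-token.py -n 5 -p "Hello, my name is"
#   python run-org-model-multi-token.py -m /path/to/Qwen3-Next-checkpoint --num-tokens 5 --save-cache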

config = AutoConfig.from_pretrained(model_path)

print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")
    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

patch_all_forward_methods(model)
patch_silu()
patch_pad()
patch_sigmoid()
patch_torch_sigmoid()

model_name = os.path.basename(model_path)

# Printing the model class to allow for easier debugging. This can be useful
# when working with models that have not been publicly released yet, which
# might require importing and using the concrete class directly instead of
# using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Store all generated tokens and logits
all_generated_tokens = []
all_logits = []

with torch.no_grad():
    # Initial forward pass
    print("\n=== Initial Forward Pass ===")
    outputs = model(input_ids, use_cache=True)
    logits = outputs.logits

    # Extract logits for the last token (next token prediction)
    last_logits = logits[0, -1, :].cpu().numpy()
    all_logits.append(last_logits)
    print(f"Logits shape: {logits.shape}")
    print(f"Last token logits shape: {last_logits.shape}")

    # Generate first token
    next_token_id = np.argmax(last_logits).item()
    all_generated_tokens.append(next_token_id)

    # Show top 5 predicted tokens for the first step
    top_indices = np.argsort(last_logits)[-5:][::-1]
    print("Top 5 predictions for first token:")
    for idx in top_indices:
        token = tokenizer.decode([idx])
        print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

    # Save KV cache if requested
    if args.save_cache:
        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)

    # Prepare for next iteration
    past_key_values = outputs.past_key_values
    current_input = torch.tensor([[next_token_id]], device=device)

    # Generate remaining tokens
    for step in range(1, args.num_tokens):
        print(f"\n=== Generation Step {step} ===")

        # Forward pass with cache
        outputs = model(
            input_ids=current_input,
            past_key_values=past_key_values,
            use_cache=True
        )
        logits = outputs.logits
        last_logits = logits[0, -1, :].cpu().numpy()
        all_logits.append(last_logits)

        # Generate next token
        next_token_id = np.argmax(last_logits).item()
        all_generated_tokens.append(next_token_id)

        # Show top 5 predicted tokens for this step
        top_indices = np.argsort(last_logits)[-5:][::-1]
        print(f"Top 5 predictions for step {step}:")
        for idx in top_indices:
            token = tokenizer.decode([idx])
            print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

        # Save KV cache if requested
        if args.save_cache:
            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)

        # Update for next iteration
        past_key_values = outputs.past_key_values
        current_input = torch.tensor([[next_token_id]], device=device)

# Save results
bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"

# Save all logits concatenated
all_logits_array = np.array(all_logits)
all_logits_array.astype(np.float32).tofile(bin_filename)
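
# The .bin file is raw float32 with no header; assuming num_steps x vocab_size,
# it could be read back with something like (illustrative, not used here):
#   logits = np.fromfile(bin_filename, dtype=np.float32).reshape(len(all_logits), -1)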

# Also save as a text file for easy inspection
with open(txt_filename, "w") as f:
    f.write(f"Generated tokens: {all_generated_tokens}\n")
    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
    for step, logits in enumerate(all_logits):
        f.write(f"=== Step {step} logits ===\n")
        for i, logit in enumerate(logits):
            f.write(f"{i}: {logit:.6f}\n")
        f.write("\n")

print("\n=== Generation Complete ===")
print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")

if args.save_cache:
    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")