run-org-model-multi-token.py

#!/usr/bin/env python3

import argparse
import os
import importlib
from pathlib import Path
import re

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np

### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
#
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===

token_counter = {}
layer_counter = {}
num_model_layers = 0


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
    - 2D tensors (seq, hidden)
    - 3D tensors (batch, seq, hidden)
    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
    - 5D tensors (leading singleton dimensions are squeezed down to 4D)

    Shows the first and last max_vals of each vector per sequence position.
    """
    global num_model_layers, layer_counter, token_counter

    t = tensor.detach().to(torch.float32).cpu()
    ten_shape = t.shape

    while t.ndim > 4:
        t = t.squeeze(0)

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we're at
    # the edge case of s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(set(first_indices + last_indices))
        separator_index = None
    else:
        # If no overlap, add a separator between the first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add a separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat
        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)
        if len(flat) >= 2 * max_vals:
            print(f" [{first_str}, ..., {last_str}]")
        else:
            print(f" [{last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")

    indexed_patterns = [r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+"]
    non_indexed_patterns = [r"k_pad", r"v_pad", r"q_pad"]

    if any(re.fullmatch(p, name) for p in indexed_patterns):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")

    if any(re.fullmatch(p, name) for p in non_indexed_patterns):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
            layer_counter[name] = 0
        else:
            layer_counter[name] = layer_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
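

# Note on the extra dumps written by summarize(): a name that matches indexed_patterns
# (e.g. "model.layers.3_out", "recurrent_cache_3") is saved to
# reference/tensors/org/<name>_<token_count>.bin, where <token_count> counts how many
# times that name has been summarized so far; a name that matches non_indexed_patterns
# (q_pad/k_pad/v_pad) additionally gets a layer counter (wrapping after
# num_model_layers layers) encoded into the filename.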


from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402

orig_conv1d_update = torch_causal_conv1d_update
orig_rope = apply_rotary_pos_emb

import torch.nn.functional as F  # noqa: E402
import typing  # noqa: E402


def patched_torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    summarize(hidden_states, "hidden_states_in")
    summarize(conv_state, "conv_state_in")
    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    summarize(hidden_states_new, "hidden_states_new")
    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")
    summarize(conv_state, "conv_state_pre")
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    summarize(conv_state, "conv_state_post")
    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    summarize(out, "out")
    summarize(out[:, :, -seq_len:], "out_proper")
    out = F.silu(out[:, :, -seq_len:])
    summarize(out, "out_silu")
    out = out.to(hidden_states.dtype)
    return out


already_dumped_rope = False


def save_tensor(tensor, filename):
    """Save a tensor to a binary file with shape information."""
    # Ensure the target directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Convert to numpy and save
    np_array = tensor.detach().cpu().numpy()
    # Save the shape first (4 int64 values), then the data
    with open(filename, 'wb') as f:
        shape = list(np_array.shape)
        while len(shape) < 4:
            shape.insert(0, 0)
        # Write the shape as int64
        shape_array = np.array(shape, dtype=np.int64)
        f.write(shape_array.tobytes())
        # Write the data as float32
        np_array_float32 = np_array.astype(np.float32)
        f.write(np_array_float32.tobytes())
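

# Illustrative counterpart to save_tensor() (not used anywhere in this script): a dump
# starts with a 4-entry int64 shape header, left-padded with zeros, followed by the
# float32 payload.
def load_tensor(filename):
    """Read a tensor dump produced by save_tensor()."""
    with open(filename, 'rb') as f:
        shape = np.fromfile(f, dtype=np.int64, count=4)
        data = np.fromfile(f, dtype=np.float32)
    dims = [int(d) for d in shape if d != 0]
    return torch.from_numpy(data.reshape(dims) if dims else data)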


def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global already_dumped_rope
    # log inputs
    summarize(q, "RoPE.q_in")
    summarize(k, "RoPE.k_in")
    summarize(cos, "cos")
    summarize(sin, "sin")
    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
    #     already_dumped_rope = True
    #     print("Dumping input tensors")
    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
    if position_ids is not None:
        summarize(position_ids, "position_ids")
    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
    # call original
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
    # log outputs
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out


def patched_torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    initial_dtype = query.dtype
    [summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm"))]
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    [summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig"))]
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    [summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra"))]
    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    [summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad"))]
    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
    query = query * scale
    summarize(query, "q_scaled")
    summarize(key, "k")
    summarize(beta.unsqueeze(-1), "beta")
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    summarize(k_beta, "k_beta")
    summarize(v_beta, "v_beta")
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    [summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh"))]
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
    # chunk decay
    g = g.cumsum(dim=-1)
    summarize(g, "g_cumsum")
    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
    summarize(bt1, "bt1")
    summarize(bt2, "bt2")
    summarize(sub, "sub")
    decay_mask = sub.tril()
    summarize(decay_mask, "sub_tril")
    decay_mask = decay_mask.exp()
    summarize(decay_mask, "sub_tril_exp")
    decay_mask = decay_mask.float()
    summarize(decay_mask, "sub_tril_exp_float")
    decay_mask = decay_mask.tril()
    summarize(decay_mask, "decay_mask")
    k_t = key.transpose(-1, -2)
    summarize(k_t, "k_t")
    kmul = k_beta @ k_t
    summarize(kmul, "k_beta @ k_t")
    # if not long:
    #     print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
    kmul_decay = kmul * decay_mask
    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
    attn = -(kmul_decay).masked_fill(mask, 0)
    summarize(attn, "attn_in")
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
        # if i <= num_heads and not long:
        #     print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
        #     print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
    summarize(attn, "attn_chunks")
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    summarize(attn, "attn_eye")
    value = attn @ v_beta
    summarize(value, "value")
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    summarize(k_cumdecay, "k_cumdecay")
    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        print(f"\n=== Processing chunk {i} ===")
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        summarize(q_i, f"q_i_chunk_{i}")
        summarize(k_i, f"k_i_chunk_{i}")
        summarize(v_i, f"v_i_chunk_{i}")
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        summarize(attn, f"attn_chunk_{i}")
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        summarize(v_prime, f"v_prime_chunk_{i}")
        v_new = v_i - v_prime
        summarize(v_new, f"v_new_chunk_{i}")
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        summarize(attn_inter, f"attn_inter_chunk_{i}")
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
        g_last = g[:, :, i, -1, None, None].exp()
        summarize(g_last, f"g_last_chunk_{i}")
        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
        last_recurrent_state = (
            last_recurrent_state * g_last
            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
        )
        summarize(last_recurrent_state, f"updated_state_chunk_{i}")
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    summarize(core_attn_out, "attn_out")
    if isinstance(last_recurrent_state, torch.Tensor):
        summarize(last_recurrent_state, "state_out")
    return core_attn_out, last_recurrent_state
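

# Note on the chunked path above: after the transpose, the variable named num_heads
# actually holds the token count, so the padding and the outer loop run over the
# sequence dimension in blocks of chunk_size tokens, with last_recurrent_state
# carrying the decayed delta-rule state from one chunk to the next. It is a chunked
# formulation of the same recurrence as the per-token implementation below.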


def patched_torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    summarize(query, "q_t")
    summarize(key, "k_t")
    summarize(value, "v_t")
    summarize(beta, "beta_t")
    summarize(g, "g_t")
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    summarize(query, "q_scaled")
    if initial_state is not None:
        summarize(initial_state, "initial_state")
    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        summarize(g_t, "g_exp_unsq")
        beta_t = beta[:, :, i].unsqueeze(-1)
        summarize(beta_t, "beta_t_unsq")
        last_recurrent_state = last_recurrent_state * g_t
        summarize(last_recurrent_state, "gated_state")
        k_unsq = k_t.unsqueeze(-1)
        summarize(k_unsq, "k_unsqueeze")
        state_k = last_recurrent_state * k_unsq
        summarize(state_k, "state_k_product")
        kv_mem = state_k.sum(dim=-2)
        summarize(kv_mem, "kv_mem")
        delta = (v_t - kv_mem) * beta_t
        summarize(delta, "delta")
        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        summarize(k_delta, "k_delta")
        last_recurrent_state = last_recurrent_state + k_delta
        summarize(last_recurrent_state, "state_plus_k_delta")
        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
        summarize(state_q_prod, "state_q_product")
        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
        summarize(core_attn_out, "core_attn_out")
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
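

# For reference, the per-token loop above implements the gated delta-rule recurrence
# (state S has shape [batch, heads, k_head_dim, v_head_dim], and q_t has already been
# scaled by 1/sqrt(k_head_dim)):
#
#   S~_t = exp(g_t) * S_{t-1}
#   S_t  = S~_t + k_t (x) (beta_t * (v_t - k_t^T S~_t))
#   o_t  = q_t^T S_t
#
# where (x) denotes the outer product over the key and value dimensions.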


import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402

qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
qwen_mod.apply_rotary_pos_emb = patched_apply_rope
qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule

# Store original functions for patching
original_functions = {}


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn


def patch_all_forward_methods(model):
    """Apply monkey patches to all forward methods in the model."""
    for name, module in model.named_modules():
        # Set layer index if applicable
        parts = name.split('.')
        module.layer_idx = -1  # Default invalid value
        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
            try:
                module.layer_idx = int(parts[2])  # Convert to integer
            except (ValueError, IndexError):
                module.layer_idx = -1
        # Apply forward hook to log all inputs/outputs
        module.register_forward_hook(debug_hook(name))
        # Additional patches for specific methods in various modules
        if hasattr(module, 'forward'):
            original_forward = module.forward

            def make_patched_forward(orig_forward, mod_name):
                def patched_forward(*args, **kwargs):
                    # Log inputs
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")
                    # Call original forward
                    result = orig_forward(*args, **kwargs)
                    if mod_name.endswith("linear_attn"):
                        cache = kwargs["cache_params"]
                        nameparts = mod_name.split(".")
                        layer_idx = -1
                        try:
                            layer_idx = int(nameparts[2])
                        except (ValueError, IndexError):
                            print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
                        rec_cache = cache.recurrent_states[layer_idx]
                        if rec_cache is not None:
                            summarize(rec_cache, f"recurrent_cache_{layer_idx}")
                    # Log output
                    if isinstance(result, torch.Tensor):
                        summarize(result, f"{mod_name}.forward.out")
                    elif isinstance(result, (tuple, list)):
                        for i, res in enumerate(result):
                            if isinstance(res, torch.Tensor):
                                summarize(res, f"{mod_name}.forward.out_{i}")
                    return result
                return patched_forward

            module.forward = make_patched_forward(original_forward, name)
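

# Note: patch_all_forward_methods() instruments every module twice: the registered
# forward hook logs tensors as "<module>_in"/"<module>_out", while the wrapped forward
# logs them as "<module>.forward.arg_<i>_in" and "<module>.forward.out[_<i>]", and it
# additionally summarizes the recurrent cache of *linear_attn modules after each call.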


def patch_silu():
    """Patch torch.nn.functional.silu to log inputs and outputs."""
    global original_functions
    if 'silu' not in original_functions:
        original_functions['silu'] = torch.nn.functional.silu

    def patched_silu(input, inplace=False):
        # Log input
        summarize(input, "silu_in")
        # Call original function
        result = original_functions['silu'](input, inplace)
        # Log output
        summarize(result, "silu_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.silu = patched_silu


def patch_sigmoid():
    """Patch torch.nn.functional.sigmoid to log inputs and outputs."""
    global original_functions
    if 'sigmoid' not in original_functions:
        original_functions['sigmoid'] = torch.nn.functional.sigmoid

    def patched_sigmoid(input):
        # Log input
        summarize(input, "sigmoid_in")
        # Call original function
        result = original_functions['sigmoid'](input)
        # Log output
        summarize(result, "sigmoid_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.sigmoid = patched_sigmoid


def patch_torch_sigmoid():
    """Patch torch.sigmoid to log inputs and outputs."""
    global original_functions
    if 'torch_sigmoid' not in original_functions:
        original_functions['torch_sigmoid'] = torch.sigmoid

    def patched_torch_sigmoid(input):
        # Log input
        summarize(input, "torch_sigmoid_in")
        # Call original function
        result = original_functions['torch_sigmoid'](input)
        # Log output
        summarize(result, "torch_sigmoid_out")
        return result

    # Replace the function in the torch module
    torch.sigmoid = patched_torch_sigmoid


def patch_pad():
    """Patch torch.nn.functional.pad to log inputs and outputs."""
    global original_functions
    if 'pad' not in original_functions:
        original_functions['pad'] = torch.nn.functional.pad

    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None):  # pyright: ignore[reportGeneralTypeIssues]
        # Log input
        summarize(input, "pad_in")
        print(f"Padding shape is {pad}")
        # Call original function
        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
        # Log output
        summarize(result, "pad_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.pad = patched_pad
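

# Illustrative sketch (not called anywhere in this script): the patch_* helpers above
# stash the original callables in original_functions, so they can be restored like this.
def unpatch_all():
    """Restore the torch functions replaced by the patch_* helpers above."""
    if 'silu' in original_functions:
        torch.nn.functional.silu = original_functions['silu']
    if 'sigmoid' in original_functions:
        torch.nn.functional.sigmoid = original_functions['sigmoid']
    if 'torch_sigmoid' in original_functions:
        torch.sigmoid = original_functions['torch_sigmoid']
    if 'pad' in original_functions:
        torch.nn.functional.pad = original_functions['pad']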


def save_kv_cache(past_key_values, step_num, data_dir, model_name):
    """Save KV cache tensors for each layer."""
    cache_dir = data_dir / f"kv_cache_step_{step_num}"
    cache_dir.mkdir(exist_ok=True)
    # Access past_key_values if available
    if past_key_values is not None:
        for layer_idx, cache_tuple in enumerate(past_key_values):
            if cache_tuple is None:
                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                continue
            # Handle different cache formats
            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                key, value = cache_tuple[0], cache_tuple[1]
                # Check if key and value are not None
                if key is not None and value is not None:
                    # Save key cache
                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
                    # Save value cache
                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                else:
                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
            else:
                # Handle other cache formats (e.g., recurrent models)
                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
                # Save as generic cache if it's a tensor
                if hasattr(cache_tuple, 'detach'):
                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
    else:
        print(f"No KV cache available at step {step_num}")


unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )

config = AutoConfig.from_pretrained(model_path)
print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
num_model_layers = config.num_hidden_layers

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")
    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

patch_all_forward_methods(model)
patch_silu()
patch_pad()
patch_sigmoid()
patch_torch_sigmoid()

model_name = os.path.basename(model_path)

# Print the model class to allow for easier debugging. This can be useful when
# working with models that have not been publicly released yet, which might
# require that the concrete class is imported and used directly instead of
# using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Store all generated tokens and logits
all_generated_tokens = []
all_logits = []

with torch.no_grad():
    # Initial forward pass
    print("\n=== Initial Forward Pass ===")
    outputs = model(input_ids, use_cache=True)
    logits = outputs.logits

    # Extract logits for the last token (next token prediction)
    last_logits = logits[0, -1, :].cpu().numpy()
    all_logits.append(last_logits)
    print(f"Logits shape: {logits.shape}")
    print(f"Last token logits shape: {last_logits.shape}")

    # Generate first token
    next_token_id = np.argmax(last_logits).item()
    all_generated_tokens.append(next_token_id)

    # Show top 5 predicted tokens for first step
    top_indices = np.argsort(last_logits)[-5:][::-1]
    print("Top 5 predictions for first token:")
    for idx in top_indices:
        token = tokenizer.decode([idx])
        print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

    # Save KV cache if requested
    if args.save_cache:
        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)

    # Prepare for next iteration
    past_key_values = outputs.past_key_values
    current_input = torch.tensor([[next_token_id]], device=device)

    # Generate remaining tokens
    for step in range(1, args.num_tokens):
        print(f"\n=== Generation Step {step} ===")

        # Forward pass with cache
        outputs = model(
            input_ids=current_input,
            past_key_values=past_key_values,
            use_cache=True
        )
        logits = outputs.logits
        last_logits = logits[0, -1, :].cpu().numpy()
        all_logits.append(last_logits)

        # Generate next token
        next_token_id = np.argmax(last_logits).item()
        all_generated_tokens.append(next_token_id)

        # Show top 5 predicted tokens for this step
        top_indices = np.argsort(last_logits)[-5:][::-1]
        print(f"Top 5 predictions for step {step}:")
        for idx in top_indices:
            token = tokenizer.decode([idx])
            print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

        # Save KV cache if requested
        if args.save_cache:
            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)

        # Update for next iteration
        past_key_values = outputs.past_key_values
        current_input = torch.tensor([[next_token_id]], device=device)

# Save results
bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"

# Save all logits concatenated
all_logits_array = np.array(all_logits)
all_logits_array.astype(np.float32).tofile(bin_filename)

# Also save as text file for easy inspection
with open(txt_filename, "w") as f:
    f.write(f"Generated tokens: {all_generated_tokens}\n")
    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
    for step, logits in enumerate(all_logits):
        f.write(f"=== Step {step} logits ===\n")
        for i, logit in enumerate(logits):
            f.write(f"{i}: {logit:.6f}\n")
        f.write("\n")

print("\n=== Generation Complete ===")
print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")
if args.save_cache:
    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")