run-org-model-multi-token.py

#!/usr/bin/env python3
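"""
Run a Hugging Face causal LM on a prompt and greedily decode a few tokens,
while monkey patching the Qwen3-Next reference ops (causal conv1d update, RoPE,
chunked and recurrent gated delta rule) and several torch functional ops so
that intermediate tensors are printed in llama.cpp "ggml_debug" style and
selected tensors are dumped under reference/tensors/org/. Per-step logits are
written to data/pytorch-<model>-multi-token.{bin,txt}, intended for comparison
against a llama.cpp run of the same model.

Example usage (paths are placeholders):

    MODEL_PATH=/path/to/model python run-org-model-multi-token.py \
        -p "Hello, my name is" -n 5 --save-cache
"""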
import argparse
import os
import importlib
from pathlib import Path
import re

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np
### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
#
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
#
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===

token_counter = {}
layer_counter = {}
num_model_layers = 0


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
    """
    Print a tensor in llama.cpp debug style.

    Supports:
      - 2D tensors (seq, hidden)
      - 3D tensors (batch, seq, hidden)
      - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads x dim_per_head
      - 5D tensors (leading singleton dimensions are squeezed away)

    Shows the first and last max_vals of each vector per sequence position.
    """
    global num_model_layers, layer_counter, token_counter

    t = tensor.detach().to(torch.float32).cpu()
    ten_shape = t.shape

    while t.ndim > 4:
        t = t.squeeze(0)

    # Determine dimensions
    if t.ndim == 3:
        _, s, _ = t.shape
    elif t.ndim == 2:
        _, s = 1, t.shape[0]
        t = t.unsqueeze(0)
    elif t.ndim == 4:
        _, s, _, _ = t.shape
    else:
        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
        return

    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
    print(" [")
    print(" [")

    # Determine indices for the first and last sequence positions
    first_indices = list(range(min(s, max_seq)))
    last_indices = list(range(max(0, s - max_seq), s))

    # Check if there's an overlap between first and last indices, or if we hit
    # the edge case of s == 2 * max_seq
    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)

    # Combine indices
    if has_overlap:
        # If there's overlap, just use the combined unique indices
        indices = sorted(set(first_indices + last_indices))
        separator_index = None
    else:
        # If no overlap, add a separator between the first and last sequences
        indices = first_indices + last_indices
        separator_index = len(first_indices)

    for i, si in enumerate(indices):
        # Add separator if needed
        if separator_index is not None and i == separator_index:
            print(" ...")

        # Extract the appropriate slice
        vec = t[0, si]
        if vec.ndim == 2:  # 4D case: flatten heads x dim_per_head
            flat = vec.flatten().tolist()
        else:  # 2D or 3D case
            flat = vec.tolist()

        # First and last slices
        first = flat[:max_vals]
        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat

        first_str = ", ".join(f"{v:12.4f}" for v in first)
        last_str = ", ".join(f"{v:12.4f}" for v in last)

        if len(flat) >= 2 * max_vals:
            print(f" [{first_str}, ..., {last_str}]")
        else:
            print(f" [{last_str}]")

    print(" ],")
    print(" ]")
    print(f" sum = {t.sum().item():.6f}\n")

    indexed_patterns = [r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+"]
    non_indexed_patterns = [r"k_pad", r"v_pad", r"q_scaled"]

    if any(re.fullmatch(p, name) for p in indexed_patterns):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")

    if any(re.fullmatch(p, name) for p in non_indexed_patterns):
        if name not in token_counter:
            token_counter[name] = 1
        if name not in layer_counter:
            layer_counter[name] = 0
        elif layer_counter[name] >= num_model_layers:
            layer_counter[name] = 0
            token_counter[name] = token_counter[name] + 1
        else:
            layer_counter[name] = layer_counter[name] + 1
            if layer_counter[name] % 4 == 3:
                layer_counter[name] = layer_counter[name] + 1  # skip attention layers
        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
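
# The patches below wrap the reference Qwen3-Next implementations from
# transformers (causal conv1d update, RoPE, and the chunked / recurrent gated
# delta rule) so that every intermediate tensor is logged via summarize().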
from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402

orig_conv1d_update = torch_causal_conv1d_update
orig_rope = apply_rotary_pos_emb

import torch.nn.functional as F  # noqa: E402
import typing  # noqa: E402
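
# Decode-time causal conv1d step, mirroring the reference
# torch_causal_conv1d_update: the rolling conv_state is concatenated in front
# of the incoming hidden states and copied back in place, a depthwise conv1d
# (groups=hidden_size) is applied, and the last seq_len outputs go through
# SiLU. summarize() calls are added around every intermediate.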
def patched_torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    summarize(hidden_states, "hidden_states_in")
    summarize(conv_state, "conv_state_in")
    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    summarize(hidden_states_new, "hidden_states_new")
    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")
    summarize(conv_state, "conv_state_pre")
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    summarize(conv_state, "conv_state_post")
    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    summarize(out, "out")
    summarize(out[:, :, -seq_len:], "out_proper")
    out = F.silu(out[:, :, -seq_len:])
    summarize(out, "out_silu")
    out = out.to(hidden_states.dtype)
    return out


already_dumped_rope = False


def save_tensor(tensor, filename):
    """Save a tensor to a binary file with shape information."""
    # Ensure the tensors directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Convert to numpy and save
    np_array = tensor.detach().cpu().numpy()
    # Save the shape first (4 int64 values), then the data
    with open(filename, 'wb') as f:
        shape = list(np_array.shape)
        while len(shape) < 4:
            shape.insert(0, 0)
        # Write the shape as int64
        shape_array = np.array(shape, dtype=np.int64)
        f.write(shape_array.tobytes())
        # Write the data as float32
        np_array_float32 = np_array.astype(np.float32)
        f.write(np_array_float32.tobytes())
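
# Layout of the .bin dumps written by save_tensor(): 32 bytes of header holding
# four int64 values (the shape, left-padded with zeros), followed by the tensor
# data as float32. A minimal sketch of a reader (hypothetical helper, not used
# by this script):
#
#   def load_tensor(path):
#       with open(path, "rb") as f:
#           shape = np.frombuffer(f.read(32), dtype=np.int64)
#           data = np.frombuffer(f.read(), dtype=np.float32)
#       return data.reshape([int(d) for d in shape if d > 0])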


def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global already_dumped_rope
    # log inputs
    summarize(q, "RoPE.q_in")
    summarize(k, "RoPE.k_in")
    summarize(cos, "cos")
    summarize(sin, "sin")
    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
    #     already_dumped_rope = True
    #     print("Dumping input tensors")
    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
    if position_ids is not None:  # truth-testing a multi-element tensor would raise
        summarize(position_ids, "position_ids")
    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
    # call original
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
    # log outputs
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out


def patched_torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    initial_dtype = query.dtype
    [summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm"))]
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    [summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig"))]
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    [summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra"))]

    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    [summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad"))]

    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
    query = query * scale
    summarize(query, "q_scaled")
    summarize(key, "k")
    summarize(beta.unsqueeze(-1), "beta")
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    summarize(k_beta, "k_beta")
    summarize(v_beta, "v_beta")

    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    [summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh"))]

    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    summarize(g, "g_cumsum")
    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
    summarize(bt1, "bt1")
    summarize(bt2, "bt2")
    summarize(sub, "sub")
    decay_mask = sub.tril()
    summarize(decay_mask, "sub_tril")
    decay_mask = decay_mask.exp()
    summarize(decay_mask, "sub_tril_exp")
    decay_mask = decay_mask.float()
    summarize(decay_mask, "sub_tril_exp_float")
    decay_mask = decay_mask.tril()
    summarize(decay_mask, "decay_mask")

    k_t = key.transpose(-1, -2)
    summarize(k_t, "k_t")
    kmul = k_beta @ k_t
    summarize(kmul, "k_beta @ k_t")
    # if not long:
    #     print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
    kmul_decay = kmul * decay_mask
    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
    attn = -(kmul_decay).masked_fill(mask, 0)
    summarize(attn, "attn_in")

    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
        # if i <= num_heads and not long:
        #     print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
        #     print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
    summarize(attn, "attn_chunks")

    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    summarize(attn, "attn_eye")
    value = attn @ v_beta
    summarize(value, "value")
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    summarize(k_cumdecay, "k_cumdecay")

    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        print(f"\n=== Processing chunk {i} ===")
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        summarize(q_i, f"q_i_chunk_{i}")
        summarize(k_i, f"k_i_chunk_{i}")
        summarize(v_i, f"v_i_chunk_{i}")
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        summarize(attn, f"attn_chunk_{i}")
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        summarize(v_prime, f"v_prime_chunk_{i}")
        v_new = v_i - v_prime
        summarize(v_new, f"v_new_chunk_{i}")
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        summarize(attn_inter, f"attn_inter_chunk_{i}")
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
        g_last = g[:, :, i, -1, None, None].exp()
        summarize(g_last, f"g_last_chunk_{i}")
        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
        last_recurrent_state = (
            last_recurrent_state * g_last
            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
        )
        summarize(last_recurrent_state, f"updated_state_chunk_{i}")

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    summarize(core_attn_out, "attn_out")
    if isinstance(last_recurrent_state, torch.Tensor):
        summarize(last_recurrent_state, "state_out")
    return core_attn_out, last_recurrent_state
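
# Per-token (decode) variant of the gated delta rule: for each position the
# recurrent state S is decayed by exp(g), a correction delta = (v - k.S) * beta
# is computed, the rank-1 update k (outer) delta is added to S, and the output
# is read out as q.S. Intermediates are logged via summarize().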
def patched_torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    summarize(query, "q_t")
    summarize(key, "k_t")
    summarize(value, "v_t")
    summarize(beta, "beta_t")
    summarize(g, "g_t")

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    summarize(query, "q_scaled")
    if initial_state is not None:
        summarize(initial_state, "initial_state")

    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        summarize(g_t, "g_exp_unsq")
        beta_t = beta[:, :, i].unsqueeze(-1)
        summarize(beta_t, "beta_t_unsq")
        last_recurrent_state = last_recurrent_state * g_t
        summarize(last_recurrent_state, "gated_state")
        k_unsq = k_t.unsqueeze(-1)
        summarize(k_unsq, "k_unsqueeze")
        state_k = last_recurrent_state * k_unsq
        summarize(state_k, "state_k_product")
        kv_mem = state_k.sum(dim=-2)
        summarize(kv_mem, "kv_mem")
        delta = (v_t - kv_mem) * beta_t
        summarize(delta, "delta")
        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        summarize(k_delta, "k_delta")
        last_recurrent_state = last_recurrent_state + k_delta
        summarize(last_recurrent_state, "state_plus_k_delta")
        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
        summarize(state_q_prod, "state_q_product")
        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
        summarize(core_attn_out, "core_attn_out")

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
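
# Install the patched implementations into the transformers module so that the
# model's forward pass calls the logging versions defined above.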
import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402

qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
qwen_mod.apply_rotary_pos_emb = patched_apply_rope
qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule

# Store original functions for patching
original_functions = {}


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn
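
# patch_all_forward_methods() walks every submodule: it records the module's
# layer index, registers a forward hook that logs the first tensor input and
# output, and wraps forward() itself so that positional tensor arguments and,
# for linear_attn modules, the recurrent cache state are summarized as well.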


def patch_all_forward_methods(model):
    """Apply monkey patches to all forward methods in the model."""
    for name, module in model.named_modules():
        # Set layer index if applicable
        parts = name.split('.')
        module.layer_idx = -1  # Default invalid value
        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
            try:
                module.layer_idx = int(parts[2])  # Convert to integer
            except (ValueError, IndexError):
                module.layer_idx = -1

        # Apply forward hook to log all inputs/outputs
        module.register_forward_hook(debug_hook(name))

        # Additional patches for specific methods in various modules
        if hasattr(module, 'forward'):
            original_forward = module.forward

            def make_patched_forward(orig_forward, mod_name):
                def patched_forward(*args, **kwargs):
                    # Log inputs
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")
                    # Call original forward
                    result = orig_forward(*args, **kwargs)
                    if mod_name.endswith("linear_attn"):
                        cache = kwargs["cache_params"]
                        nameparts = mod_name.split(".")
                        layer_idx = -1
                        try:
                            layer_idx = int(nameparts[2])
                        except (ValueError, IndexError):
                            print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
                        rec_cache = cache.recurrent_states[layer_idx]
                        if rec_cache is not None:
                            summarize(rec_cache, f"recurrent_cache_{layer_idx}")
                    # Log output
                    if isinstance(result, torch.Tensor):
                        summarize(result, f"{mod_name}.forward.out")
                    elif isinstance(result, (tuple, list)):
                        for i, res in enumerate(result):
                            if isinstance(res, torch.Tensor):
                                summarize(res, f"{mod_name}.forward.out_{i}")
                    return result

                return patched_forward

            module.forward = make_patched_forward(original_forward, name)
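
# The patches below wrap torch-level functional ops (F.silu, F.sigmoid,
# torch.sigmoid, F.pad) so that every call made inside the model is logged,
# not just module boundaries.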


def patch_silu():
    """Patch torch.nn.functional.silu to log inputs and outputs."""
    global original_functions
    if 'silu' not in original_functions:
        original_functions['silu'] = torch.nn.functional.silu

    def patched_silu(input, inplace=False):
        # Log input
        summarize(input, "silu_in")
        # Call original function
        result = original_functions['silu'](input, inplace)
        # Log output
        summarize(result, "silu_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.silu = patched_silu


def patch_sigmoid():
    """Patch torch.nn.functional.sigmoid to log inputs and outputs."""
    global original_functions
    if 'sigmoid' not in original_functions:
        original_functions['sigmoid'] = torch.nn.functional.sigmoid

    def patched_sigmoid(input):
        # Log input
        summarize(input, "sigmoid_in")
        # Call original function
        result = original_functions['sigmoid'](input)
        # Log output
        summarize(result, "sigmoid_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.sigmoid = patched_sigmoid


def patch_torch_sigmoid():
    """Patch torch.sigmoid to log inputs and outputs."""
    global original_functions
    if 'torch_sigmoid' not in original_functions:
        original_functions['torch_sigmoid'] = torch.sigmoid

    def patched_torch_sigmoid(input):
        # Log input
        summarize(input, "torch_sigmoid_in")
        # Call original function
        result = original_functions['torch_sigmoid'](input)
        # Log output
        summarize(result, "torch_sigmoid_out")
        return result

    # Replace the function in the torch module
    torch.sigmoid = patched_torch_sigmoid


def patch_pad():
    """Patch torch.nn.functional.pad to log inputs and outputs."""
    global original_functions
    if 'pad' not in original_functions:
        original_functions['pad'] = torch.nn.functional.pad

    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None):  # pyright: ignore[reportGeneralTypeIssues]
        # Log input
        summarize(input, "pad_in")
        print(f"Padding shape is {pad}")
        # Call original function
        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
        # Log output
        summarize(result, "pad_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.pad = patched_pad


def save_kv_cache(past_key_values, step_num, data_dir, model_name):
    """Save KV cache tensors for each layer."""
    cache_dir = data_dir / f"kv_cache_step_{step_num}"
    cache_dir.mkdir(exist_ok=True)

    # Access past_key_values if available
    if past_key_values is not None:
        for layer_idx, cache_tuple in enumerate(past_key_values):
            if cache_tuple is None:
                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                continue

            # Handle different cache formats
            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                key, value = cache_tuple[0], cache_tuple[1]
                # Check that key and value are not None
                if key is not None and value is not None:
                    # Save key cache
                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
                    # Save value cache
                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                else:
                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
            else:
                # Handle other cache formats (e.g., recurrent models)
                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
                # Save as a generic cache if it's a tensor
                if hasattr(cache_tuple, 'detach'):
                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
    else:
        print(f"No KV cache available at step {step_num}")

unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
    )

config = AutoConfig.from_pretrained(model_path)
print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
num_model_layers = config.num_hidden_layers

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)

if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")
    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

patch_all_forward_methods(model)
patch_silu()
patch_pad()
patch_sigmoid()
patch_torch_sigmoid()

model_name = os.path.basename(model_path)

# Print the model class to allow for easier debugging. This can be useful when
# working with models that have not been publicly released yet, which might
# require that the concrete class is imported and used directly instead of
# using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Store all generated tokens and logits
all_generated_tokens = []
all_logits = []

with torch.no_grad():
    # Initial forward pass
    print(f"\n=== Initial Forward Pass ===")
    outputs = model(input_ids, use_cache=True)
    logits = outputs.logits

    # Extract logits for the last token (next token prediction)
    last_logits = logits[0, -1, :].cpu().numpy()
    all_logits.append(last_logits)
    print(f"Logits shape: {logits.shape}")
    print(f"Last token logits shape: {last_logits.shape}")

    # Generate first token
    next_token_id = np.argmax(last_logits).item()
    all_generated_tokens.append(next_token_id)

    # Show top 5 predicted tokens for first step
    top_indices = np.argsort(last_logits)[-5:][::-1]
    print("Top 5 predictions for first token:")
    for idx in top_indices:
        token = tokenizer.decode([idx])
        print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

    # Save KV cache if requested
    if args.save_cache:
        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)

    # Prepare for next iteration
    past_key_values = outputs.past_key_values
    current_input = torch.tensor([[next_token_id]], device=device)

    # Generate remaining tokens
    for step in range(1, args.num_tokens):
        print(f"\n=== Generation Step {step} ===")

        # Forward pass with cache
        outputs = model(
            input_ids=current_input,
            past_key_values=past_key_values,
            use_cache=True
        )
        logits = outputs.logits
        last_logits = logits[0, -1, :].cpu().numpy()
        all_logits.append(last_logits)

        # Generate next token
        next_token_id = np.argmax(last_logits).item()
        all_generated_tokens.append(next_token_id)

        # Show top 5 predicted tokens for this step
        top_indices = np.argsort(last_logits)[-5:][::-1]
        print(f"Top 5 predictions for step {step}:")
        for idx in top_indices:
            token = tokenizer.decode([idx])
            print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")

        # Save KV cache if requested
        if args.save_cache:
            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)

        # Update for next iteration
        past_key_values = outputs.past_key_values
        current_input = torch.tensor([[next_token_id]], device=device)

# Save results
bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"

# Save all logits concatenated
all_logits_array = np.array(all_logits)
all_logits_array.astype(np.float32).tofile(bin_filename)

# Also save as a text file for easy inspection
with open(txt_filename, "w") as f:
    f.write(f"Generated tokens: {all_generated_tokens}\n")
    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
    for step, logits in enumerate(all_logits):
        f.write(f"=== Step {step} logits ===\n")
        for i, logit in enumerate(logits):
            f.write(f"{i}: {logit:.6f}\n")
        f.write("\n")

print(f"\n=== Generation Complete ===")
print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")
if args.save_cache:
    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")