run-org-model-multi-token.py

#!/usr/bin/env python3
import argparse
import os
import importlib
from pathlib import Path
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np

### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model).
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
#     # log inputs
#     summarize(q, "RoPE.q_in")
#     summarize(k, "RoPE.k_in")
#     # call original
#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
#     # log outputs
#     summarize(q_out, "RoPE.q_out")
#     summarize(k_out, "RoPE.k_out")
#     return q_out, k_out
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### === END ROPE DEBUG ===

token_counter = {}
layer_counter = {}
num_model_layers = 0


def summarize(tensor: torch.Tensor, name: str, max_seq: int = 4):
    """
    Print a tensor in llama.cpp debug style.

    Shows the first and last max_seq values of each vector per sequence position,
    and dumps selected tensors to reference/tensors/org/ for later comparison.
    """
    global num_model_layers, layer_counter, token_counter
    torch.set_printoptions(precision=6, edgeitems=max_seq, linewidth=160, sci_mode=False, threshold=50)
    t = tensor.detach().to(torch.float32).cpu()
    print(f"ggml_debug: {name} = (f32) ... = {{{t.shape}}}\n")
    print(t)
    print(f"\n sum = {t.sum().item():.6f}\n")
    # Tensor names that already carry a layer index: one dump per token.
    indexed_patterns = [r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+"]
    # Tensor names without a layer index: track layer/token counters manually.
    non_indexed_patterns = [r"k_pad", r"v_pad", r"q_scaled"]
    if any(re.fullmatch(p, name) for p in indexed_patterns):
        if name not in token_counter:
            token_counter[name] = 1
        else:
            token_counter[name] = token_counter[name] + 1
        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
    if any(re.fullmatch(p, name) for p in non_indexed_patterns):
        if name not in token_counter:
            token_counter[name] = 1
        if name not in layer_counter:
            layer_counter[name] = 0
        elif layer_counter[name] >= num_model_layers:
            layer_counter[name] = 0
            token_counter[name] = token_counter[name] + 1
        else:
            layer_counter[name] = layer_counter[name] + 1
            if layer_counter[name] % 4 == 3:
                layer_counter[name] = layer_counter[name] + 1  # skip attention layers
        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm, repeat_kv  # noqa: E402
from transformers.processing_utils import Unpack  # noqa: E402
from transformers.utils.generic import TransformersKwargs  # noqa: E402

orig_conv1d_update = torch_causal_conv1d_update
orig_rope = apply_rotary_pos_emb

import torch.nn.functional as F  # noqa: E402
import typing  # noqa: E402
from torch import nn  # noqa: E402


def patched_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: typing.Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    print(f"\nAttention scaling: {scaling}\n")
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    summarize(attn_output, "attn_output")
    return attn_output, attn_weights
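

# Note: apart from the summarize() call, the function above performs the usual eager
# attention computation as written: GQA keys/values are expanded with repeat_kv, then
# softmax(Q @ K^T * scaling + causal_mask) @ V is applied and the result transposed
# back to [batch, seq, heads, head_dim].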


def patched_torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    summarize(hidden_states, "hidden_states_in")
    summarize(conv_state, "conv_state_in")
    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    summarize(hidden_states_new, "hidden_states_new")
    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")
    summarize(conv_state, "conv_state_pre")
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    summarize(conv_state, "conv_state_post")
    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    summarize(out, "out")
    summarize(out[:, :, -seq_len:], "out_proper")
    out = F.silu(out[:, :, -seq_len:])
    summarize(out, "out_silu")
    out = out.to(hidden_states.dtype)
    return out


already_dumped_rope = False


def save_tensor(tensor, filename):
    """Save a tensor to a binary file with shape information."""
    # Ensure the tensors directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Convert to numpy and save
    np_array = tensor.detach().cpu().numpy()
    # Save the shape first (4 int64 values, left-padded with zeros), then the data
    with open(filename, 'wb') as f:
        shape = list(np_array.shape)
        while len(shape) < 4:
            shape.insert(0, 0)
        # Write shape as int64
        shape_array = np.array(shape, dtype=np.int64)
        f.write(shape_array.tobytes())
        # Write data as float32
        np_array_float32 = np_array.astype(np.float32)
        f.write(np_array_float32.tobytes())
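

# For reference, a file written by save_tensor() can be read back roughly like this
# (hypothetical helper, not used by this script):
#
# def load_tensor(filename):
#     with open(filename, 'rb') as f:
#         shape = np.frombuffer(f.read(4 * 8), dtype=np.int64)   # 4 int64 header values
#         shape = [d for d in shape if d != 0]                    # drop the zero padding
#         data = np.frombuffer(f.read(), dtype=np.float32)        # remaining bytes are float32
#     return data.reshape(shape)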


def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global already_dumped_rope
    # log inputs
    summarize(q, "RoPE.q_in")
    summarize(k, "RoPE.k_in")
    summarize(cos, "cos")
    summarize(sin, "sin")
    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
    #     already_dumped_rope = True
    #     print("Dumping input tensors")
    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
    if position_ids is not None:
        summarize(position_ids, "position_ids")
    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
    # call original
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
    # log outputs
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out


def patched_torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    initial_dtype = query.dtype
    [summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm"))]
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    [summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig"))]
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    [summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra"))]
    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    [summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad"))]
    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
    query = query * scale
    summarize(query, "q_scaled")
    summarize(key, "k")
    summarize(beta.unsqueeze(-1), "beta")
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    summarize(k_beta, "k_beta")
    summarize(v_beta, "v_beta")
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    [summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh"))]
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
    # chunk decay
    g = g.cumsum(dim=-1)
    summarize(g, "g_cumsum")
    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
    summarize(bt1, "bt1")
    summarize(bt2, "bt2")
    summarize(sub, "sub")
    decay_mask = sub.tril()
    summarize(decay_mask, "sub_tril")
    decay_mask = decay_mask.exp()
    summarize(decay_mask, "sub_tril_exp")
    decay_mask = decay_mask.float()
    summarize(decay_mask, "sub_tril_exp_float")
    decay_mask = decay_mask.tril()
    summarize(decay_mask, "decay_mask")
    k_t = key.transpose(-1, -2)
    summarize(k_t, "k_t")
    kmul = k_beta @ k_t
    summarize(kmul, "k_beta @ k_t")
    # if not long:
    #     print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
    kmul_decay = kmul * decay_mask
    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
    attn = -(kmul_decay).masked_fill(mask, 0)
    summarize(attn, "attn_in")
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
        # if i <= num_heads and not long:
        #     print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
        #     print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
    summarize(attn, "attn_chunks")
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    summarize(attn, "attn_eye")
    value = attn @ v_beta
    summarize(value, "value")
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    summarize(k_cumdecay, "k_cumdecay")
    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        print(f"\n=== Processing chunk {i} ===")
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        summarize(q_i, f"q_i_chunk_{i}")
        summarize(k_i, f"k_i_chunk_{i}")
        summarize(v_i, f"v_i_chunk_{i}")
        q_k_trans = q_i @ k_i.transpose(-1, -2)
        summarize(q_k_trans, f"q_k_trans_{i}")
        q_k_trans_decay = q_k_trans * decay_mask[:, :, i]
        summarize(q_k_trans_decay, f"q_k_trans_decay_{i}")
        attn = q_k_trans_decay.masked_fill_(mask, 0)
        summarize(attn, f"attn_chunk_{i}")
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        summarize(v_prime, f"v_prime_chunk_{i}")
        v_new = v_i - v_prime
        summarize(v_new, f"v_new_chunk_{i}")
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        summarize(attn_inter, f"attn_inter_chunk_{i}")
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
        g_last = g[:, :, i, -1, None, None].exp()
        summarize(g_last, f"g_last_chunk_{i}")
        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
        summarize(g_diff_exp, f"g_diff_exp_chunk_{i}")
        state_g_last = last_recurrent_state * g_last
        summarize(state_g_last, f"state_g_last_{i}")
        k_g_diffexp = k_i * g_diff_exp[..., None]
        summarize(k_g_diffexp, f"k_g_diffexp_{i}")
        k_g_diffexp_T = k_g_diffexp.transpose(-1, -2)
        summarize(k_g_diffexp_T, f"k_g_diffexp_T_{i}")
        kgd_mul_vnew = k_g_diffexp_T @ v_new
        summarize(kgd_mul_vnew, f"kgd_mul_vnew_{i}")
        last_recurrent_state = state_g_last + kgd_mul_vnew
        summarize(last_recurrent_state, f"updated_state_chunk_{i}")
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    summarize(core_attn_out, "attn_out_reshaped")
    core_attn_out = core_attn_out[:, :, :num_heads]
    summarize(core_attn_out, "attn_out_truncated")
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    summarize(core_attn_out, "attn_out")
    if isinstance(last_recurrent_state, torch.Tensor):
        summarize(last_recurrent_state, "state_out")
    return core_attn_out, last_recurrent_state
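

# Rough summary of the chunked path above, as implemented: per chunk i the code forms
# an intra-chunk term (attn @ v_new, causally masked and weighted by decay_mask), an
# inter-chunk term ((q_i * exp(g)) @ state) that reads from the carried state, and then
# updates the recurrent state as
#     state <- state * exp(g_last) + (k_i * exp(g_last - g))^T @ v_new
# The summarize() calls dump every intermediate so they can be compared one-to-one
# against another implementation's tensor dumps.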


def patched_torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    summarize(query, "q_t")
    summarize(key, "k_t")
    summarize(value, "v_t")
    summarize(beta, "beta_t")
    summarize(g, "g_t")
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    summarize(query, "q_scaled")
    if initial_state is not None:
        summarize(initial_state, "initial_state")
    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        summarize(g_t, "g_exp_unsq")
        beta_t = beta[:, :, i].unsqueeze(-1)
        summarize(beta_t, "beta_t_unsq")
        last_recurrent_state = last_recurrent_state * g_t
        summarize(last_recurrent_state, "gated_state")
        k_unsq = k_t.unsqueeze(-1)
        summarize(k_unsq, "k_unsqueeze")
        state_k = last_recurrent_state * k_unsq
        summarize(state_k, "state_k_product")
        kv_mem = state_k.sum(dim=-2)
        summarize(kv_mem, "kv_mem")
        delta = (v_t - kv_mem) * beta_t
        summarize(delta, "delta")
        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        summarize(k_delta, "k_delta")
        last_recurrent_state = last_recurrent_state + k_delta
        summarize(last_recurrent_state, "state_plus_k_delta")
        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
        summarize(state_q_prod, "state_q_product")
        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
        summarize(core_attn_out, "core_attn_out")
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
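

# For reference, the per-token update above is the gated delta rule written out with
# broadcasting (the state S has shape [batch, heads, k_head_dim, v_head_dim]):
#     S      <- S * exp(g_t)
#     kv_mem  = k_t^T S                      (what S already stores for key k_t)
#     delta   = (v_t - kv_mem) * beta_t
#     S      <- S + outer(k_t, delta)
#     o_t     = q_t^T S                      (with q_t pre-scaled by 1/sqrt(head_dim))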


import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402

qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
qwen_mod.apply_rotary_pos_emb = patched_apply_rope
qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
qwen_mod.eager_attention_forward = patched_eager_attention_forward

# Store original functions for patching
original_functions = {}


def debug_hook(name):
    def fn(_m, input, output):
        if isinstance(input, torch.Tensor):
            summarize(input, name + "_in")
        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
            summarize(input[0], name + "_in")
        if isinstance(output, torch.Tensor):
            summarize(output, name + "_out")
        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            summarize(output[0], name + "_out")
    return fn


def patch_all_forward_methods(model):
    """Apply monkey patches to all forward methods in the model."""
    for name, module in model.named_modules():
        # Set layer index if applicable
        parts = name.split('.')
        module.layer_idx = -1  # Default invalid value
        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
            try:
                module.layer_idx = int(parts[2])  # Convert to integer
            except (ValueError, IndexError):
                module.layer_idx = -1
        # Apply a forward hook to log all inputs/outputs
        module.register_forward_hook(debug_hook(name))
        # Additional patches for specific methods in various modules
        if hasattr(module, 'forward'):
            original_forward = module.forward

            def make_patched_forward(orig_forward, mod_name):
                def patched_forward(*args, **kwargs):
                    # Log inputs
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")
                    # Call original forward
                    result = orig_forward(*args, **kwargs)
                    if mod_name.endswith("linear_attn"):
                        cache = kwargs["cache_params"]
                        nameparts = mod_name.split(".")
                        layer_idx = -1
                        try:
                            layer_idx = int(nameparts[2])
                        except (ValueError, IndexError):
                            print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
                        rec_cache = cache.recurrent_states[layer_idx]
                        if rec_cache is not None:
                            summarize(rec_cache, f"recurrent_cache_{layer_idx}")
                    # Log output
                    if isinstance(result, torch.Tensor):
                        summarize(result, f"{mod_name}.forward.out")
                    elif isinstance(result, (tuple, list)):
                        for i, res in enumerate(result):
                            if isinstance(res, torch.Tensor):
                                summarize(res, f"{mod_name}.forward.out_{i}")
                    return result
                return patched_forward

            module.forward = make_patched_forward(original_forward, name)


def patch_silu():
    """Patch torch.nn.functional.silu to log inputs and outputs."""
    global original_functions
    if 'silu' not in original_functions:
        original_functions['silu'] = torch.nn.functional.silu

    def patched_silu(input, inplace=False):
        # Log input
        summarize(input, "silu_in")
        # Call original function
        result = original_functions['silu'](input, inplace)
        # Log output
        summarize(result, "silu_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.silu = patched_silu


def patch_sigmoid():
    """Patch torch.nn.functional.sigmoid to log inputs and outputs."""
    global original_functions
    if 'sigmoid' not in original_functions:
        original_functions['sigmoid'] = torch.nn.functional.sigmoid

    def patched_sigmoid(input):
        # Log input
        summarize(input, "sigmoid_in")
        # Call original function
        result = original_functions['sigmoid'](input)
        # Log output
        summarize(result, "sigmoid_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.sigmoid = patched_sigmoid


def patch_torch_sigmoid():
    """Patch torch.sigmoid to log inputs and outputs."""
    global original_functions
    if 'torch_sigmoid' not in original_functions:
        original_functions['torch_sigmoid'] = torch.sigmoid

    def patched_torch_sigmoid(input):
        # Log input
        summarize(input, "torch_sigmoid_in")
        # Call original function
        result = original_functions['torch_sigmoid'](input)
        # Log output
        summarize(result, "torch_sigmoid_out")
        return result

    # Replace the function at the torch module level
    torch.sigmoid = patched_torch_sigmoid


def patch_pad():
    """Patch torch.nn.functional.pad to log inputs and outputs."""
    global original_functions
    if 'pad' not in original_functions:
        original_functions['pad'] = torch.nn.functional.pad

    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None):  # pyright: ignore[reportGeneralTypeIssues]
        # Log input
        summarize(input, "pad_in")
        print(f"Padding shape is {pad}")
        # Call original function
        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
        # Log output
        summarize(result, "pad_out")
        return result

    # Replace the function in the torch.nn.functional module
    torch.nn.functional.pad = patched_pad


def save_kv_cache(past_key_values, step_num, data_dir, model_name):
    """Save KV cache tensors for each layer."""
    cache_dir = data_dir / f"kv_cache_step_{step_num}"
    cache_dir.mkdir(exist_ok=True)
    # Access past_key_values if available
    if past_key_values is not None:
        for layer_idx, cache_tuple in enumerate(past_key_values):
            if cache_tuple is None:
                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                continue
            # Handle different cache formats
            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                key, value = cache_tuple[0], cache_tuple[1]
                # Check that key and value are not None
                if key is not None and value is not None:
                    # Save key cache
                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
                    # Save value cache
                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                else:
                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
            else:
                # Handle other cache formats (e.g., recurrent models)
                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
                # Save as a generic cache if it's a tensor
                if hasattr(cache_tuple, 'detach'):
                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
    else:
        print(f"No KV cache available at step {step_num}")


unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")

parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
args = parser.parse_args()

model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
    parser.error(
        "Model path must be specified either via the --model-path argument or the MODEL_PATH environment variable"
    )

config = AutoConfig.from_pretrained(model_path)
print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
num_model_layers = config.num_hidden_layers

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)
if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = (
        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    )
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")
    try:
        model_class = getattr(
            importlib.import_module(unreleased_module_path), class_name
        )
        model = model_class.from_pretrained(
            model_path
        )  # Note: from_pretrained, not fromPretrained
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", offload_folder="offload"
    )

patch_all_forward_methods(model)
patch_silu()
patch_pad()
patch_sigmoid()
patch_torch_sigmoid()

model_name = os.path.basename(model_path)
# Printing the model class to allow for easier debugging. This can be useful when
# working with models that have not been publicly released yet, which might require
# importing and using the concrete class directly instead of AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

device = next(model.parameters()).device
prompt = args.prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Store all generated tokens and logits
all_generated_tokens = []
all_logits = []

model.config._attn_implementation = "eager"

with torch.no_grad():
    # Initial forward pass
    print("\n=== Initial Forward Pass ===")
    outputs = model(input_ids, use_cache=True)
    logits = outputs.logits
    # Extract logits for the last token (next-token prediction)
    last_logits = logits[0, -1, :].cpu().numpy()
    all_logits.append(last_logits)
    print(f"Logits shape: {logits.shape}")
    print(f"Last token logits shape: {last_logits.shape}")
    # Generate first token
    next_token_id = np.argmax(last_logits).item()
    all_generated_tokens.append(next_token_id)
    # Show top 5 predicted tokens for the first step
    top_indices = np.argsort(last_logits)[-5:][::-1]
    print("Top 5 predictions for first token:")
    for idx in top_indices:
        token = tokenizer.decode([idx])
        print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
    # Save KV cache if requested
    if args.save_cache:
        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)
    # Prepare for the next iteration
    past_key_values = outputs.past_key_values
    current_input = torch.tensor([[next_token_id]], device=device)
    # Generate remaining tokens
    for step in range(1, args.num_tokens):
        print(f"\n=== Generation Step {step} ===")
        # Forward pass with cache
        outputs = model(
            input_ids=current_input,
            past_key_values=past_key_values,
            use_cache=True
        )
        logits = outputs.logits
        last_logits = logits[0, -1, :].cpu().numpy()
        all_logits.append(last_logits)
        # Generate next token
        next_token_id = np.argmax(last_logits).item()
        all_generated_tokens.append(next_token_id)
        # Show top 5 predicted tokens for this step
        top_indices = np.argsort(last_logits)[-5:][::-1]
        print(f"Top 5 predictions for step {step}:")
        for idx in top_indices:
            token = tokenizer.decode([idx])
            print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
        # Save KV cache if requested
        if args.save_cache:
            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)
        # Update for the next iteration
        past_key_values = outputs.past_key_values
        current_input = torch.tensor([[next_token_id]], device=device)

# Save results
bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"

# Save all logits concatenated
all_logits_array = np.array(all_logits)
all_logits_array.astype(np.float32).tofile(bin_filename)

# Also save as a text file for easy inspection
with open(txt_filename, "w") as f:
    f.write(f"Generated tokens: {all_generated_tokens}\n")
    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
    for step, logits in enumerate(all_logits):
        f.write(f"=== Step {step} logits ===\n")
        for i, logit in enumerate(logits):
            f.write(f"{i}: {logit:.6f}\n")
        f.write("\n")

print("\n=== Generation Complete ===")
print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")
if args.save_cache:
    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")