@@ -118,14 +118,15 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     print(f" sum = {t.sum().item():.6f}\n")

     pattern = r"model\.layers\.[0-9]+_out"
-    if re.fullmatch(pattern, name):
+    pattern2 = r"recurrent_cache_[0-9]+"
+    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")

-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
@@ -189,17 +190,17 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     summarize(k, "RoPE.k_in")
     summarize(cos, "cos")
     summarize(sin, "sin")
-    if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
-        already_dumped_rope = True
-        print("Dumping input tensors")
-        save_tensor(q, "reference/tensors/testrope_q_in.bin")
-        save_tensor(k, "reference/tensors/testrope_k_in.bin")
-        save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
-        save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
+    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
+    #     already_dumped_rope = True
+    #     print("Dumping input tensors")
+    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
+    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
+    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
+    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")

-    if position_ids:
+    if position_ids is not None:  # avoid tensor truthiness, which raises for multi-element tensors
         summarize(position_ids, "position_ids")
-    print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
+    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")

     # call original
     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)

@@ -210,9 +211,231 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

     return q_out, k_out
+def patched_torch_chunk_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    chunk_size=64,
+    initial_state=None,
+    output_final_state=False,
+    use_qk_l2norm_in_kernel=False,
+    long=False,
+):
+    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
+    initial_dtype = query.dtype
+    for x, y in ((query, "q_prenorm"), (key, "k_prenorm")):
+        summarize(x, y)
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    for x, y in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig")):
+        summarize(x, y)
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    for x, y in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra")):
+        summarize(x, y)
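+    # NOTE: the unpacking below keeps the upstream helper's names, which appear
+    # swapped: after transpose(1, 2) dim 1 holds the heads and dim 2 the tokens,
+    # so "num_heads" is really the sequence length (it is what gets padded to a
+    # multiple of chunk_size) and "sequence_length" is really the head count.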
+    batch_size, sequence_length, num_heads, k_head_dim = key.shape
+    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
+    v_head_dim = value.shape[-1]
+    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
+    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
+    query = F.pad(query, (0, 0, 0, pad_size))
+    key = F.pad(key, (0, 0, 0, pad_size))
+    value = F.pad(value, (0, 0, 0, pad_size))
+    beta = F.pad(beta, (0, pad_size))
+    g = F.pad(g, (0, pad_size))
+    for x, y in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad")):
+        summarize(x, y)
+    tot_heads = num_heads + pad_size
+    scale = 1 / (query.shape[-1] ** 0.5)
+    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    summarize(key, "k")
+    summarize(beta.unsqueeze(-1), "beta")
+    v_beta = value * beta.unsqueeze(-1)
+    k_beta = key * beta.unsqueeze(-1)
+    summarize(k_beta, "k_beta")
+    summarize(v_beta, "v_beta")
+    # reshape to chunks
+    query, key, value, k_beta, v_beta = [
+        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
+    ]
+    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+    for x, y in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh")):
+        summarize(x, y)
+
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
+
+    # chunk decay
+    g = g.cumsum(dim=-1)
+    summarize(g, "g_cumsum")
+    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
+    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
+    summarize(bt1, "bt1")
+    summarize(bt2, "bt2")
+    summarize(sub, "sub")
+    decay_mask = sub.tril()
+    summarize(decay_mask, "sub_tril")
+    decay_mask = decay_mask.exp()
+    summarize(decay_mask, "sub_tril_exp")
+    decay_mask = decay_mask.float()
+    summarize(decay_mask, "sub_tril_exp_float")
+    decay_mask = decay_mask.tril()
+    summarize(decay_mask, "decay_mask")
+    k_t = key.transpose(-1, -2)
+    summarize(k_t, "k_t")
+    kmul = k_beta @ k_t
+    summarize(kmul, "k_beta @ k_t")
+    # if not long:
+    #     print(f"k_beta @ k_t:\n{kmul[:, :, :, :8, :8]}\n\n")
+    kmul_decay = kmul * decay_mask
+    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
+    attn = -kmul_decay.masked_fill(mask, 0)
+    summarize(attn, "attn_in")
+    for i in range(1, chunk_size):
+        row = attn[..., i, :i].clone()
+        sub = attn[..., :i, :i].clone()
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+        # if i <= num_heads and not long:
+        #     print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1) * sub}\n")
+        #     print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
+    summarize(attn, "attn_chunks")
+    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    summarize(attn, "attn_eye")
+
+    value = attn @ v_beta
+    summarize(value, "value")
+
+    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    summarize(k_cumdecay, "k_cumdecay")
+
+    last_recurrent_state = (
+        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+    core_attn_out = torch.zeros_like(value)
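+    # Rebuild the mask strictly upper-triangular (diagonal=1) this time: within a
+    # chunk, each position may attend to itself and to earlier positions only.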
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
+
+    # for each chunk
+    for i in range(tot_heads // chunk_size):
+        print(f"\n=== Processing chunk {i} ===")
+        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+        summarize(q_i, f"q_i_chunk_{i}")
+        summarize(k_i, f"k_i_chunk_{i}")
+        summarize(v_i, f"v_i_chunk_{i}")
+
+        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        summarize(attn, f"attn_chunk_{i}")
+
+        v_prime = k_cumdecay[:, :, i] @ last_recurrent_state
+        summarize(v_prime, f"v_prime_chunk_{i}")
+
+        v_new = v_i - v_prime
+        summarize(v_new, f"v_new_chunk_{i}")
+
+        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        summarize(attn_inter, f"attn_inter_chunk_{i}")
+
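+        # chunk output = what the queries read from the carried-over state
+        # (inter-chunk) plus causal attention within the chunk over the
+        # corrected values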
+        core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
+
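+        # carry the state across the chunk boundary: decay it by the chunk's
+        # total gate, then add the outer products of the re-scaled keys with
+        # the corrected values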
+        g_last = g[:, :, i, -1, None, None].exp()
+        summarize(g_last, f"g_last_chunk_{i}")
+
+        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
+        last_recurrent_state = (
+            last_recurrent_state * g_last
+            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
+        )
+        summarize(last_recurrent_state, f"updated_state_chunk_{i}")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    core_attn_out = core_attn_out[:, :, :num_heads]
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    summarize(core_attn_out, "attn_out")
+    if not long:
+        print(f"attn_out:\n{core_attn_out}\n\n")
+
+    if isinstance(last_recurrent_state, torch.Tensor):
+        summarize(last_recurrent_state, "state_out")
+        if not long:
+            print(f"state_out:\n{last_recurrent_state}\n\n")
+    return core_attn_out, last_recurrent_state
+
+
+def patched_torch_recurrent_gated_delta_rule(
+    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
+):
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    summarize(query, "q_t")
+    summarize(key, "k_t")
+    summarize(value, "v_t")
+    summarize(beta, "beta_t")
+    summarize(g, "g_t")
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    if initial_state is not None:
+        summarize(initial_state, "initial_state")
+
+    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
+    last_recurrent_state = (
+        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+
+    for i in range(sequence_length):
+        q_t = query[:, :, i]
+        k_t = key[:, :, i]
+        v_t = value[:, :, i]
+        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+        summarize(g_t, "g_exp_unsq")
+        beta_t = beta[:, :, i].unsqueeze(-1)
+        summarize(beta_t, "beta_t_unsq")
+
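+        # gated delta rule, one token at a time: decay the state, read what is
+        # currently stored under k_t, then write back the beta-scaled correction
+        # toward v_t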
+        last_recurrent_state = last_recurrent_state * g_t
+        summarize(last_recurrent_state, "gated_state")
+        k_unsq = k_t.unsqueeze(-1)
+        summarize(k_unsq, "k_unsqueeze")
+        state_k = last_recurrent_state * k_unsq
+        summarize(state_k, "state_k_product")
+        kv_mem = state_k.sum(dim=-2)
+        summarize(kv_mem, "kv_mem")
+        delta = (v_t - kv_mem) * beta_t
+        summarize(delta, "delta")
+        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        summarize(k_delta, "k_delta")
+        last_recurrent_state = last_recurrent_state + k_delta
+        summarize(last_recurrent_state, "state_plus_k_delta")
+        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
+        summarize(state_q_prod, "state_q_product")
+        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
+        summarize(core_attn_out, "core_attn_out")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
 import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402
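+# Route both delta-rule paths (chunked and recurrent) through the instrumented
+# copies above.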
+qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
+qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule

 # Store original functions for patching
 original_functions = {}
@@ -259,6 +482,18 @@ def patch_all_forward_methods(model):
         # Call original forward
         result = orig_forward(*args, **kwargs)

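+        # After a linear-attention module runs, snapshot its recurrent cache;
+        # summarize() matches the recurrent_cache_<idx> pattern added above and
+        # dumps the tensor to disk.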
+        if mod_name.endswith("linear_attn"):
+            cache = kwargs.get("cache_params")
+            nameparts = mod_name.split(".")
+            layer_idx = -1
+            try:
+                layer_idx = int(nameparts[2])
+            except (ValueError, IndexError):
+                print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
+            if cache is not None and layer_idx >= 0:
+                rec_cache = cache.recurrent_states[layer_idx]
+                if rec_cache is not None:
+                    summarize(rec_cache, f"recurrent_cache_{layer_idx}")
+
         # Log output
         if isinstance(result, torch.Tensor):
             summarize(result, f"{mod_name}.forward.out")