Piotr Wilkin 3 месяцев назад
Родитель
Сommit
10032affcf
1 измененных файлов с 245 добавлено и 10 удалено
  1. 245 10
      examples/model-conversion/scripts/causal/run-org-model-multi-token.py

+ 245 - 10
examples/model-conversion/scripts/causal/run-org-model-multi-token.py

@@ -118,14 +118,15 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     print(f"                                     sum = {t.sum().item():.6f}\n")
 
     pattern = r"model\.layers\.[0-9]+_out"
-    if re.fullmatch(pattern, name):
+    pattern2 = r"recurrent_cache_[0-9]+"
+    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
 
-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
@@ -189,17 +190,17 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     summarize(k, "RoPE.k_in")
     summarize(cos, "cos")
     summarize(sin, "sin")
-    if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
-        already_dumped_rope = True
-        print("Dumping input tensors")
-        save_tensor(q, "reference/tensors/testrope_q_in.bin")
-        save_tensor(k, "reference/tensors/testrope_k_in.bin")
-        save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
-        save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
+    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
+    #     already_dumped_rope = True
+    #     print("Dumping input tensors")
+    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
+    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
+    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
+    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
 
     if position_ids:
         summarize(position_ids, "position_ids")
-    print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
+    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
 
     # call original
     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
@@ -210,9 +211,231 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
     return q_out, k_out
 
+def patched_torch_chunk_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    chunk_size=64,
+    initial_state=None,
+    output_final_state=False,
+    use_qk_l2norm_in_kernel=False,
+    long=False
+):
+    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
+    initial_dtype = query.dtype
+    [ summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm")) ]
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    [ summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig")) ]
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    [ summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra")) ]
+    batch_size, sequence_length, num_heads, k_head_dim = key.shape
+    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
+    v_head_dim = value.shape[-1]
+    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
+    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
+    query = F.pad(query, (0, 0, 0, pad_size))
+    key = F.pad(key, (0, 0, 0, pad_size))
+    value = F.pad(value, (0, 0, 0, pad_size))
+    beta = F.pad(beta, (0, pad_size))
+    g = F.pad(g, (0, pad_size))
+    [ summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad")) ]
+    tot_heads = num_heads + pad_size
+    scale = 1 / (query.shape[-1] ** 0.5)
+    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    summarize(key, "k")
+    summarize(beta.unsqueeze(-1), "beta")
+    v_beta = value * beta.unsqueeze(-1)
+    k_beta = key * beta.unsqueeze(-1)
+    summarize(k_beta, "k_beta")
+    summarize(v_beta, "v_beta")
+    # reshape to chunks
+    query, key, value, k_beta, v_beta = [
+        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
+    ]
+    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+    [ summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh")) ]
+
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
+
+    # chunk decay
+    g = g.cumsum(dim=-1)
+    summarize(g, "g_cumsum")
+    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
+    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
+    summarize(bt1, "bt1")
+    summarize(bt2, "bt2")
+    summarize(sub, "sub")
+    decay_mask = sub.tril()
+    summarize(decay_mask, "sub_tril")
+    decay_mask = decay_mask.exp()
+    summarize(decay_mask, "sub_tril_exp")
+    decay_mask = decay_mask.float()
+    summarize(decay_mask, "sub_tril_exp_float")
+    decay_mask = decay_mask.tril()
+    summarize(decay_mask, "decay_mask")
+    k_t = key.transpose(-1, -2)
+    summarize(k_t, "k_t")
+    kmul = k_beta @ k_t
+    summarize(kmul, "k_beta @ k_t")
+    #if not long:
+        #print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
+    kmul_decay = kmul * decay_mask
+    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
+    attn = -(kmul_decay).masked_fill(mask, 0)
+    summarize(attn, "attn_in")
+    for i in range(1, chunk_size):
+        row = attn[..., i, :i].clone()
+        sub = attn[..., :i, :i].clone()
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+        #if i <= num_heads and not long: 
+            #print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
+            #print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
+    summarize(attn, "attn_chunks")
+    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    summarize(attn, "attn_eye")
+    
+    value = attn @ v_beta
+    summarize(value, "value")
+        
+    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    summarize(k_cumdecay, "k_cumdecay")
+    
+    last_recurrent_state = (
+        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+    core_attn_out = torch.zeros_like(value)
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
+
+    # for each chunk
+    for i in range(0, tot_heads // chunk_size):
+        print(f"\n=== Processing chunk {i} ===")
+        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+        summarize(q_i, f"q_i_chunk_{i}")
+        summarize(k_i, f"k_i_chunk_{i}")
+        summarize(v_i, f"v_i_chunk_{i}")
+        
+        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        summarize(attn, f"attn_chunk_{i}")
+        
+        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        summarize(v_prime, f"v_prime_chunk_{i}")
+        
+        v_new = v_i - v_prime
+        summarize(v_new, f"v_new_chunk_{i}")
+        
+        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        summarize(attn_inter, f"attn_inter_chunk_{i}")
+        
+        core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
+        
+        g_last = g[:, :, i, -1, None, None].exp()
+        summarize(g_last, f"g_last_chunk_{i}")
+        
+        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
+        last_recurrent_state = (
+            last_recurrent_state * g_last
+            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
+        )
+        summarize(last_recurrent_state, f"updated_state_chunk_{i}")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    core_attn_out = core_attn_out[:, :, :num_heads]
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    summarize(core_attn_out, "attn_out")
+    if not long:
+        print(f"attn_out:\n{core_attn_out}\n\n")
+        
+    if isinstance(last_recurrent_state, torch.Tensor):
+        summarize(last_recurrent_state, "state_out")
+        if not long:
+            print(f"state_out:\n{last_recurrent_state}\n\n")
+    return core_attn_out, last_recurrent_state
+
+
+def patched_torch_recurrent_gated_delta_rule(
+    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
+):
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    summarize(query, "q_t")
+    summarize(key, "k_t")
+    summarize(value, "v_t")
+    summarize(beta, "beta_t")
+    summarize(g, "g_t")
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    if initial_state is not None:
+        summarize(initial_state, "initial_state")
+
+    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
+    last_recurrent_state = (
+        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+
+    for i in range(sequence_length):
+        q_t = query[:, :, i]
+        k_t = key[:, :, i]
+        v_t = value[:, :, i]
+        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+        summarize(g_t, "g_exp_unsq")
+        beta_t = beta[:, :, i].unsqueeze(-1)
+        summarize(beta_t, "beta_t_unsq")
+
+        last_recurrent_state = last_recurrent_state * g_t
+        summarize(last_recurrent_state, "gated_state")
+        k_unsq = k_t.unsqueeze(-1)
+        summarize(k_unsq, "k_unsqueeze")
+        state_k = last_recurrent_state * k_unsq
+        summarize(state_k, "state_k_product")
+        kv_mem = state_k.sum(dim=-2)
+        summarize(kv_mem, "kv_mem")
+        delta = (v_t - kv_mem) * beta_t
+        summarize(delta, "delta")
+        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        summarize(k_delta, "k_delta")
+        last_recurrent_state = last_recurrent_state + k_delta
+        summarize(last_recurrent_state, "state_plus_k_delta")
+        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
+        summarize(state_q_prod, "state_q_product")
+        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
+        summarize(core_attn_out, "core_attn_out")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
 import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402
+qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
+qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
 
 # Store original functions for patching
 original_functions = {}
@@ -259,6 +482,18 @@ def patch_all_forward_methods(model):
                     # Call original forward
                     result = orig_forward(*args, **kwargs)
 
+                    if mod_name.endswith("linear_attn"):
+                        cache = kwargs["cache_params"]
+                        nameparts = mod_name.split(".")
+                        layer_idx = -1
+                        try:
+                            layer_idx = int(nameparts[2])
+                        except (ValueError, IndexError):
+                            print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
+                        rec_cache = cache.recurrent_states[layer_idx]
+                        if rec_cache is not None:
+                            summarize(rec_cache, f"recurrent_cache_{layer_idx}")
+
                     # Log output
                     if isinstance(result, torch.Tensor):
                         summarize(result, f"{mod_name}.forward.out")