@@ -40,9 +40,11 @@ import numpy as np
 
 ### == END ROPE DEBUG ===
 
 token_counter = {}
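+# layer_counter tracks, per tensor name, which layer the current call belongs to;
+# num_model_layers is filled in from the model config once it has been loaded below.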
+layer_counter = {}
+num_model_layers = 0
 
 def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
-    global token, token_counter
+    global num_model_layers, layer_counter, token_counter
     """
     Print a tensor in llama.cpp debug style.
@@ -120,15 +122,27 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
         print(" ]")
     print(f" sum = {t.sum().item():.6f}\n")
 
-    pattern = r"model\.layers\.[0-9]+_out"
-    pattern2 = r"recurrent_cache_[0-9]+"
-    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
+    indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
+    non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_pad" ]
+
+    if any(re.fullmatch(p, name) for p in indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
 
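+    # Names like k_pad/v_pad/q_pad carry no layer index, so track the layer per name
+    # and include it (together with the token counter) in the dump filename.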
+    if any(re.fullmatch(p, name) for p in non_indexed_patterns):
+        if name not in token_counter:
+            token_counter[name] = 1
+        else:
+            token_counter[name] = token_counter[name] + 1
+        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+            layer_counter[name] = 0
+        else:
+            layer_counter[name] = layer_counter[name] + 1
+        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
+
 from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
@@ -223,10 +237,8 @@ def patched_torch_chunk_gated_delta_rule(
     chunk_size=64,
     initial_state=None,
     output_final_state=False,
-    use_qk_l2norm_in_kernel=False,
-    long=False
+    use_qk_l2norm_in_kernel=False
 ):
-    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
     initial_dtype = query.dtype
     [ summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm")) ]
     if use_qk_l2norm_in_kernel:
@@ -359,13 +371,10 @@ def patched_torch_chunk_gated_delta_rule(
     core_attn_out = core_attn_out[:, :, :num_heads]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
-    if not long:
-        print(f"attn_out:\n{core_attn_out}\n\n")
 
     if isinstance(last_recurrent_state, torch.Tensor):
         summarize(last_recurrent_state, "state_out")
-        if not long:
-            print(f"state_out:\n{last_recurrent_state}\n\n")
+
     return core_attn_out, last_recurrent_state
 
 
@@ -667,6 +676,8 @@ print("Number of layers: ", config.num_hidden_layers)
 print("BOS token id: ", config.bos_token_id)
 print("EOS token id: ", config.eos_token_id)
 
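+# Expose the configured layer count to summarize() so its per-name layer counter can wrap.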
+num_model_layers = config.num_hidden_layers
+
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 config = AutoConfig.from_pretrained(model_path)