@@ -69,7 +69,7 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
t = t.unsqueeze(0)
elif t.ndim == 4:
_, s, _, _ = t.shape
-
+
else:
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
return
@@ -124,7 +124,6 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
 
indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_scaled" ]
-
if any(re.fullmatch(p, name) for p in indexed_patterns):
if name not in token_counter:
token_counter[name] = 1
@@ -135,13 +134,17 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
if any(re.fullmatch(p, name) for p in non_indexed_patterns):
if name not in token_counter:
token_counter[name] = 1
- else:
- token_counter[name] = token_counter[name] + 1
- if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+
+ if name not in layer_counter:
+ layer_counter[name] = 0
+ elif layer_counter[name] >= num_model_layers:
layer_counter[name] = 0
+ token_counter[name] = token_counter[name] + 1
else:
layer_counter[name] = layer_counter[name] + 1
- save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
+ if layer_counter[name] % 4 == 3:
+ layer_counter[name] = layer_counter[name] + 1 # skip attention layers
+ save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
 
from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm # noqa: E402
orig_conv1d_update = torch_causal_conv1d_update
@@ -181,20 +184,20 @@ def save_tensor(tensor, filename):
"""Save tensor to binary file with shape information."""
# Ensure tensors directory exists
os.makedirs(os.path.dirname(filename), exist_ok=True)
-
+
# Convert to numpy and save
np_array = tensor.detach().cpu().numpy()
-
+
# Save shape first (4 int64 values), then data
with open(filename, 'wb') as f:
shape = list(np_array.shape)
while len(shape) < 4:
shape.insert(0, 0)
-
+
# Write shape as int64
shape_array = np.array(shape, dtype=np.int64)
f.write(shape_array.tobytes())
-
+
# Write data as float32
np_array_float32 = np_array.astype(np.float32)
f.write(np_array_float32.tobytes())
@@ -311,19 +314,19 @@ def patched_torch_chunk_gated_delta_rule(
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
- #if i <= num_heads and not long:
+ #if i <= num_heads and not long:
#print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
#print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
summarize(attn, "attn_chunks")
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
summarize(attn, "attn_eye")
-
+
value = attn @ v_beta
summarize(value, "value")
-
+
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
summarize(k_cumdecay, "k_cumdecay")
-
+
last_recurrent_state = (
torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
if initial_state is None
@@ -339,25 +342,25 @@ def patched_torch_chunk_gated_delta_rule(
summarize(q_i, f"q_i_chunk_{i}")
summarize(k_i, f"k_i_chunk_{i}")
summarize(v_i, f"v_i_chunk_{i}")
-
+
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
summarize(attn, f"attn_chunk_{i}")
-
+
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
summarize(v_prime, f"v_prime_chunk_{i}")
-
+
v_new = v_i - v_prime
summarize(v_new, f"v_new_chunk_{i}")
-
+
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
summarize(attn_inter, f"attn_inter_chunk_{i}")
-
+
core_attn_out[:, :, i] = attn_inter + attn @ v_new
summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
-
+
g_last = g[:, :, i, -1, None, None].exp()
summarize(g_last, f"g_last_chunk_{i}")
-
+
g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
last_recurrent_state = (
last_recurrent_state * g_last
@@ -371,7 +374,7 @@ def patched_torch_chunk_gated_delta_rule(
core_attn_out = core_attn_out[:, :, :num_heads]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
summarize(core_attn_out, "attn_out")
-
+
if isinstance(last_recurrent_state, torch.Tensor):
summarize(last_recurrent_state, "state_out")
 
@@ -615,28 +618,28 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
"""Save KV cache tensors for each layer"""
cache_dir = data_dir / f"kv_cache_step_{step_num}"
cache_dir.mkdir(exist_ok=True)
-
+
# Access past_key_values if available
if past_key_values is not None:
for layer_idx, cache_tuple in enumerate(past_key_values):
if cache_tuple is None:
print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
continue
-
+
# Handle different cache formats
if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
key, value = cache_tuple[0], cache_tuple[1]
-
+
# Check if key and value are not None
if key is not None and value is not None:
# Save key cache
key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
-
+
# Save value cache
value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
-
+
print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
else:
print(f"Key or value is None for layer {layer_idx} at step {step_num}")
@@ -738,67 +741,67 @@ with torch.no_grad():
print(f"\n=== Initial Forward Pass ===")
outputs = model(input_ids, use_cache=True)
logits = outputs.logits
-
+
# Extract logits for the last token (next token prediction)
last_logits = logits[0, -1, :].cpu().numpy()
all_logits.append(last_logits)
-
+
print(f"Logits shape: {logits.shape}")
print(f"Last token logits shape: {last_logits.shape}")
-
+
# Generate first token
next_token_id = np.argmax(last_logits).item()
all_generated_tokens.append(next_token_id)
-
+
# Show top 5 predicted tokens for first step
top_indices = np.argsort(last_logits)[-5:][::-1]
print("Top 5 predictions for first token:")
for idx in top_indices:
token = tokenizer.decode([idx])
print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-
+
print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-
+
# Save KV cache if requested
if args.save_cache:
save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)
-
+
# Prepare for next iteration
past_key_values = outputs.past_key_values
current_input = torch.tensor([[next_token_id]], device=device)
-
+
# Generate remaining tokens
for step in range(1, args.num_tokens):
print(f"\n=== Generation Step {step} ===")
-
+
# Forward pass with cache
outputs = model(
- input_ids=current_input,
+ input_ids=current_input,
past_key_values=past_key_values,
use_cache=True
)
-
+
logits = outputs.logits
last_logits = logits[0, -1, :].cpu().numpy()
all_logits.append(last_logits)
-
+
# Generate next token
next_token_id = np.argmax(last_logits).item()
all_generated_tokens.append(next_token_id)
-
+
# Show top 5 predicted tokens for this step
top_indices = np.argsort(last_logits)[-5:][::-1]
print(f"Top 5 predictions for step {step}:")
for idx in top_indices:
token = tokenizer.decode([idx])
print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-
+
print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-
+
# Save KV cache if requested
if args.save_cache:
save_kv_cache(outputs.past_key_values, step, data_dir, model_name)
-
+
# Update for next iteration
past_key_values = outputs.past_key_values
current_input = torch.tensor([[next_token_id]], device=device)
@@ -816,7 +819,7 @@ with open(txt_filename, "w") as f:
f.write(f"Generated tokens: {all_generated_tokens}\n")
f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
-
+
for step, logits in enumerate(all_logits):
f.write(f"=== Step {step} logits ===\n")
for i, logit in enumerate(logits):
@@ -832,4 +835,4 @@ print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")
 
if args.save_cache:
- print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
+ print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
|