Piotr Wilkin 3 месяцев назад
Родитель
Сommit
a4fe12821b
1 измененных файлов с 48 добавлено и 45 удалено
  1. 48 45
      examples/model-conversion/scripts/causal/run-org-model-multi-token.py

+ 48 - 45
examples/model-conversion/scripts/causal/run-org-model-multi-token.py

@@ -69,7 +69,7 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
         t = t.unsqueeze(0)
     elif t.ndim == 4:
         _, s, _, _ = t.shape
-    
+
     else:
         print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
         return
@@ -124,7 +124,6 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
 
     indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
     non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_scaled" ]
-    
     if any(re.fullmatch(p, name) for p in indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
@@ -135,13 +134,17 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     if any(re.fullmatch(p, name) for p in non_indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
-        else:
-            token_counter[name] = token_counter[name] + 1
-        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+
+        if name not in layer_counter:
+            layer_counter[name] = 0
+        elif layer_counter[name] >= num_model_layers:
             layer_counter[name] = 0
+            token_counter[name] = token_counter[name] + 1
         else:
             layer_counter[name] = layer_counter[name] + 1
-        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
+            if layer_counter[name] % 4 == 3:
+                layer_counter[name] = layer_counter[name] + 1  # skip attention layers
+        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
 
 from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
@@ -181,20 +184,20 @@ def save_tensor(tensor, filename):
     """Save tensor to binary file with shape information."""
     # Ensure tensors directory exists
     os.makedirs(os.path.dirname(filename), exist_ok=True)
-    
+
     # Convert to numpy and save
     np_array = tensor.detach().cpu().numpy()
-    
+
     # Save shape first (4 int64 values), then data
     with open(filename, 'wb') as f:
         shape = list(np_array.shape)
         while len(shape) < 4:
             shape.insert(0, 0)
-        
+
         # Write shape as int64
         shape_array = np.array(shape, dtype=np.int64)
         f.write(shape_array.tobytes())
-        
+
         # Write data as float32
         np_array_float32 = np_array.astype(np.float32)
         f.write(np_array_float32.tobytes())
@@ -311,19 +314,19 @@ def patched_torch_chunk_gated_delta_rule(
         row = attn[..., i, :i].clone()
         sub = attn[..., :i, :i].clone()
         attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-        #if i <= num_heads and not long: 
+        #if i <= num_heads and not long:
             #print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
             #print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
     summarize(attn, "attn_chunks")
     attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
     summarize(attn, "attn_eye")
-    
+
     value = attn @ v_beta
     summarize(value, "value")
-        
+
     k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
     summarize(k_cumdecay, "k_cumdecay")
-    
+
     last_recurrent_state = (
         torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
         if initial_state is None
@@ -339,25 +342,25 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(q_i, f"q_i_chunk_{i}")
         summarize(k_i, f"k_i_chunk_{i}")
         summarize(v_i, f"v_i_chunk_{i}")
-        
+
         attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         summarize(attn, f"attn_chunk_{i}")
-        
+
         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         summarize(v_prime, f"v_prime_chunk_{i}")
-        
+
         v_new = v_i - v_prime
         summarize(v_new, f"v_new_chunk_{i}")
-        
+
         attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
         summarize(attn_inter, f"attn_inter_chunk_{i}")
-        
+
         core_attn_out[:, :, i] = attn_inter + attn @ v_new
         summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
-        
+
         g_last = g[:, :, i, -1, None, None].exp()
         summarize(g_last, f"g_last_chunk_{i}")
-        
+
         g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
         last_recurrent_state = (
             last_recurrent_state * g_last
@@ -371,7 +374,7 @@ def patched_torch_chunk_gated_delta_rule(
     core_attn_out = core_attn_out[:, :, :num_heads]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
-        
+
     if isinstance(last_recurrent_state, torch.Tensor):
         summarize(last_recurrent_state, "state_out")
 
@@ -615,28 +618,28 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
     """Save KV cache tensors for each layer"""
     cache_dir = data_dir / f"kv_cache_step_{step_num}"
     cache_dir.mkdir(exist_ok=True)
-    
+
     # Access past_key_values if available
     if past_key_values is not None:
         for layer_idx, cache_tuple in enumerate(past_key_values):
             if cache_tuple is None:
                 print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                 continue
-                
+
             # Handle different cache formats
             if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                 key, value = cache_tuple[0], cache_tuple[1]
-                
+
                 # Check if key and value are not None
                 if key is not None and value is not None:
                     # Save key cache
                     key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                     key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
-                    
+
                     # Save value cache
                     value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                     value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
-                    
+
                     print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                 else:
                     print(f"Key or value is None for layer {layer_idx} at step {step_num}")
@@ -738,67 +741,67 @@ with torch.no_grad():
     print(f"\n=== Initial Forward Pass ===")
     outputs = model(input_ids, use_cache=True)
     logits = outputs.logits
-    
+
     # Extract logits for the last token (next token prediction)
     last_logits = logits[0, -1, :].cpu().numpy()
     all_logits.append(last_logits)
-    
+
     print(f"Logits shape: {logits.shape}")
     print(f"Last token logits shape: {last_logits.shape}")
-    
+
     # Generate first token
     next_token_id = np.argmax(last_logits).item()
     all_generated_tokens.append(next_token_id)
-    
+
     # Show top 5 predicted tokens for first step
     top_indices = np.argsort(last_logits)[-5:][::-1]
     print("Top 5 predictions for first token:")
     for idx in top_indices:
         token = tokenizer.decode([idx])
         print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-    
+
     print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-    
+
     # Save KV cache if requested
     if args.save_cache:
         save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)
-    
+
     # Prepare for next iteration
     past_key_values = outputs.past_key_values
     current_input = torch.tensor([[next_token_id]], device=device)
-    
+
     # Generate remaining tokens
     for step in range(1, args.num_tokens):
         print(f"\n=== Generation Step {step} ===")
-        
+
         # Forward pass with cache
         outputs = model(
-            input_ids=current_input, 
+            input_ids=current_input,
             past_key_values=past_key_values,
             use_cache=True
         )
-        
+
         logits = outputs.logits
         last_logits = logits[0, -1, :].cpu().numpy()
         all_logits.append(last_logits)
-        
+
         # Generate next token
         next_token_id = np.argmax(last_logits).item()
         all_generated_tokens.append(next_token_id)
-        
+
         # Show top 5 predicted tokens for this step
         top_indices = np.argsort(last_logits)[-5:][::-1]
         print(f"Top 5 predictions for step {step}:")
         for idx in top_indices:
             token = tokenizer.decode([idx])
             print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-        
+
         print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-        
+
         # Save KV cache if requested
         if args.save_cache:
             save_kv_cache(outputs.past_key_values, step, data_dir, model_name)
-        
+
         # Update for next iteration
         past_key_values = outputs.past_key_values
         current_input = torch.tensor([[next_token_id]], device=device)
@@ -816,7 +819,7 @@ with open(txt_filename, "w") as f:
     f.write(f"Generated tokens: {all_generated_tokens}\n")
     f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
     f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
-    
+
     for step, logits in enumerate(all_logits):
         f.write(f"=== Step {step} logits ===\n")
         for i, logit in enumerate(logits):
@@ -832,4 +835,4 @@ print(f"Saved bin logits to: {bin_filename}")
 print(f"Saved txt logits to: {txt_filename}")
 
 if args.save_cache:
-    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
+    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")