Piotr Wilkin 3 months ago
parent
commit
20424d8785
1 file changed with 10 additions and 22 deletions

+ 10 - 22
src/models/llm_build_qwen3next.cpp

@@ -279,24 +279,21 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     cb(q, "q_postscale", il);
     cb(beta, "beta_sigmoid", il);
 
-    // First, permute to chunked format: [S_k, n_tokens, H_k, n_seqs]
+    // Pad along the token dimension (dim 2) first, before the permute
+    q = ggml_pad(ctx, q, 0, 0, pad_size, 0);
+    k = ggml_pad(ctx, k, 0, 0, pad_size, 0);
+    v = ggml_pad(ctx, v, 0, 0, pad_size, 0);
+
     q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
-    cb(q, "q_reshape", il);
     k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
-    cb(k, "k_reshape", il);
     v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
-    cb(v, "v_reshape", il);
     
     beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
     cb(beta, "beta_reshape", il);
 
     g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
     cb(g, "g_permute", il);
-
-    // Then, pad the second dimension (n_tokens) to chunk_size
-    q = ggml_pad(ctx, q, 0, pad_size, 0, 0); 
-    k = ggml_pad(ctx, k, 0, pad_size, 0, 0);
-    v = ggml_pad(ctx, v, 0, pad_size, 0, 0);
+
     // ... except for beta and g, where the token dimension sits in dim 0 after their permutes
     beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0);
     g = ggml_pad(ctx, g, pad_size, 0, 0, 0);
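
For context, a minimal standalone sketch of the new pad-then-permute order against ggml's public API. The tensor sizes below are illustrative, not the model's; ggml_pad appends zeros at the end of each dimension, so padding dim 2 while q/k/v are still in [head_dim, n_head, n_tokens, n_seqs] layout grows the token dimension before the permute moves it into dim 1:

// Sketch only: sizes and the pad_size formula are assumptions for illustration.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = { /*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t head_dim = 4, n_head = 2, n_tokens = 5, n_seqs = 1;
    const int64_t chunk_size = 8;
    // round n_tokens up to a multiple of chunk_size (hypothetical values)
    const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;

    // [head_dim, n_head, n_tokens, n_seqs] -- the layout before the permute
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens, n_seqs);

    // ggml_pad appends zeros at the end of each dimension; only dim 2 (tokens) grows here
    struct ggml_tensor * q_pad = ggml_pad(ctx, q, 0, 0, (int) pad_size, 0);

    // permute to the chunked layout [head_dim, n_tokens_padded, n_head, n_seqs]
    struct ggml_tensor * q_chunked = ggml_cont(ctx, ggml_permute(ctx, q_pad, 0, 2, 1, 3));

    printf("tokens: %lld -> %lld, chunked: [%lld, %lld, %lld, %lld]\n",
           (long long) n_tokens, (long long) q_pad->ne[2],
           (long long) q_chunked->ne[0], (long long) q_chunked->ne[1],
           (long long) q_chunked->ne[2], (long long) q_chunked->ne[3]);

    ggml_free(ctx);
    return 0;
}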
@@ -704,23 +701,14 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
-        // GGML tensor layout: [head_dim, num_heads, n_seq_tokens, n_seqs]
-        
-        // Step 1: Flatten the sequence and batch dimensions to work with them more easily
-        ggml_tensor * q_flat = ggml_reshape_2d(ctx0, q_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
-        ggml_tensor * k_flat = ggml_reshape_2d(ctx0, k_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
-        
-        // Step 2: Reshape to prepare for repeat_interleave
-        // From [head_dim, num_k_heads * n_seq_tokens * n_seqs]
-        // To [head_dim, num_k_heads, 1, n_seq_tokens * n_seqs]
-        ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
-        ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+        ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+        ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
         
-        // Step 3: Repeat along the third dimension (the new dimension with size 1)
+        // Repeat along the third dimension (the new dimension with size 1)
         ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
         ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
         
-        // Step 4: Reshape back to merge the head and repeat dimensions
+        // Reshape back to merge the head and repeat dimensions
         // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
         // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
         q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
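
For reference, a standalone sketch of this reshape -> repeat -> reshape head expansion against ggml's public API. The sizes are made up for illustration and the variable names only mirror the layer's; nothing here is the model's actual code:

// Sketch only: head counts and token counts are illustrative assumptions.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = { /*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t head_k_dim = 4, num_k_heads = 2, num_v_heads = 6;
    const int64_t n_seq_tokens = 3, n_seqs = 1;
    const int64_t repeat_factor = num_v_heads / num_k_heads; // 3

    // [head_dim, num_k_heads, n_seq_tokens, n_seqs], as in the layer
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
        head_k_dim, num_k_heads, n_seq_tokens, n_seqs);

    // insert a size-1 axis after the head dimension ...
    struct ggml_tensor * k_r = ggml_reshape_4d(ctx, k,
        head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
    // ... broadcast it to repeat_factor copies ...
    struct ggml_tensor * k_rep = ggml_repeat_4d(ctx, k_r,
        head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
    // ... and fold the copies back into the head dimension
    struct ggml_tensor * k_out = ggml_reshape_4d(ctx, k_rep,
        head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);

    printf("heads: %lld -> %lld\n", (long long) k->ne[1], (long long) k_out->ne[1]);

    ggml_free(ctx);
    return 0;
}

The size-1 axis inserted by the first reshape is what ggml_repeat_4d broadcasts; the final reshape folds the repeat_factor copies back into the head dimension (ne[1]), which is why the commit can drop the earlier explicit 2D flatten step.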