Piotr Wilkin 3 månader sedan
förälder
incheckning
413652178f
1 ändrade filer med 19 tillägg och 9 borttagningar
  1. 19 9
      src/models/llm_build_qwen3next.cpp

+ 19 - 9
src/models/llm_build_qwen3next.cpp

@@ -704,17 +704,27 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
-        // Step 1: Reshape to add a new dimension for the repeats
-        ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
-        ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+        // GGML tensor layout: [head_dim, num_heads, n_seq_tokens, n_seqs]
         
-        // Step 2: Expand along the new dimension
-        ggml_tensor * q_expanded = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
-        ggml_tensor * k_expanded = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
+        // Step 1: Flatten the sequence and batch dimensions to work with them more easily
+        ggml_tensor * q_flat = ggml_reshape_2d(ctx0, q_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
+        ggml_tensor * k_flat = ggml_reshape_2d(ctx0, k_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
         
-        // Step 3: Reshape back to merge the repeated dimensions
-        q_conv = ggml_reshape_4d(ctx0, q_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
-        k_conv = ggml_reshape_4d(ctx0, k_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        // Step 2: Reshape to prepare for repeat_interleave
+        // From [head_dim, num_k_heads * n_seq_tokens * n_seqs]
+        // To [head_dim, num_k_heads, 1, n_seq_tokens * n_seqs]
+        ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+        ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+        
+        // Step 3: Repeat along the third dimension (the new dimension with size 1)
+        ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
+        ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
+        
+        // Step 4: Reshape back to merge the head and repeat dimensions
+        // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
+        // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
+        q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
     }
 
     cb(q_conv, "q_conv_predelta", il);