|
|
@@ -704,17 +704,27 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
|
|
int64_t repeat_factor = num_v_heads / num_k_heads;
|
|
|
|
|
|
- // Step 1: Reshape to add a new dimension for the repeats
|
|
|
- ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
|
|
|
- ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
|
|
|
+ // GGML tensor layout: [head_dim, num_heads, n_seq_tokens, n_seqs]
|
|
|
|
|
|
- // Step 2: Expand along the new dimension
|
|
|
- ggml_tensor * q_expanded = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
|
|
|
- ggml_tensor * k_expanded = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
|
|
|
+ // Step 1: Flatten the sequence and batch dimensions to work with them more easily
|
|
|
+ ggml_tensor * q_flat = ggml_reshape_2d(ctx0, q_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
|
|
|
+ ggml_tensor * k_flat = ggml_reshape_2d(ctx0, k_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
|
|
|
|
|
|
- // Step 3: Reshape back to merge the repeated dimensions
|
|
|
- q_conv = ggml_reshape_4d(ctx0, q_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
|
|
|
- k_conv = ggml_reshape_4d(ctx0, k_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
|
|
|
+ // Step 2: Reshape to prepare for repeat_interleave
|
|
|
+ // From [head_dim, num_k_heads * n_seq_tokens * n_seqs]
|
|
|
+ // To [head_dim, num_k_heads, 1, n_seq_tokens * n_seqs]
|
|
|
+ ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
|
|
|
+ ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
|
|
|
+
|
|
|
+ // Step 3: Repeat along the third dimension (the new dimension with size 1)
|
|
|
+ ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
|
|
|
+ ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
|
|
|
+
|
|
|
+ // Step 4: Reshape back to merge the head and repeat dimensions
|
|
|
+ // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
|
|
|
+ // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
|
|
|
+ q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
|
|
|
+ k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
|
|
|
}
|
|
|
|
|
|
cb(q_conv, "q_conv_predelta", il);
|