Parcourir la source

QKV splits done right

Piotr Wilkin il y a 3 mois
Parent
commit
594c1f98ef
1 fichier modifié avec 20 ajouts et 22 suppressions
  1. src/models/llm_build_qwen3next.cpp (+20 −22)

+ 20 - 22
src/models/llm_build_qwen3next.cpp

@@ -305,11 +305,11 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
 
     int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
 
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
 
 
     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
@@ -358,27 +358,23 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         head_v_dim * num_v_heads / num_k_heads   // z size
     };
 
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads,
-                                                       n_tokens, n_seqs, split_sizes_qkvz[0] * sizeof(float),
-                                                       mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], 0));
+    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs, 
+                                            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
     cb(query, "q", il);
 
 
-    ggml_tensor * key =
-        ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
-                                     split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
-                                     mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
+    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                        split_sizes_qkvz[0] * sizeof(float)));
     cb(key, "k", il);
 
 
-    ggml_tensor * value =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
-                     split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
     cb(value, "v", il);
 
 
-    ggml_tensor * z =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
-                     split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
+    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                        (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
     cb(z, "z", il);
 
 
     // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
@@ -456,15 +452,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
 
     // Extract the convolved Q, K, V from conv_output
     ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
-                                                        head_k_dim, conv_output->nb[1], conv_output->nb[2], 0));
+                                                        conv_output->nb[1], conv_output->nb[2], conv_output->nb[3], 0));
     cb(q_conv, "q_conv", il);
     ggml_tensor * k_conv = ggml_cont(
-        ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs, head_k_dim, conv_output->nb[1],
-                           conv_output->nb[2], head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+        ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs, 
+                conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
+                head_k_dim * num_k_heads * ggml_element_size(conv_output)));
     cb(q_conv, "k_conv", il);
     ggml_tensor * v_conv = ggml_cont(
-        ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs, head_v_dim, conv_output->nb[1],
-                           conv_output->nb[2], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+        ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs, 
+            conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
+            2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
     cb(q_conv, "v_conv", il);
 
 
     ggml_build_forward_expand(gf, ssm_states_all);