View source code

Yes, I finally managed to implement it with ssm_conv :>

Piotr Wilkin 3 months ago
Parent
Commit
232ec56251
1 changed file with 22 additions and 17 deletions

+ 22 - 17
src/models/llm_build_qwen3next.cpp

@@ -501,8 +501,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
     ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
     qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+    cb(qkv_mixed, "qkv_mixed", il);
+
     qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_concatenated", il);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);
 
     // Calculate the total conv dimension
     int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
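
A note on the permute above: in ggml, each argument of ggml_permute gives the destination slot of the corresponding source dimension, so (1, 0, 2, 3) swaps dims 0 and 1 and moves the token axis innermost, which is the input layout ggml_ssm_conv expects. A minimal standalone sketch, with toy sizes that are not the commit's:

    // Toy check of the axis swap done by ggml_permute(ctx, t, 1, 0, 2, 3):
    // each argument is the destination slot of the corresponding source dim.
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = { 16u * 1024 * 1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        // stand-in for qkv_mixed: [conv_dim = 8, n_tokens = 4, n_seqs = 2]
        ggml_tensor * qkv = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 4, 2);
        ggml_tensor * p   = ggml_permute(ctx, qkv, 1, 0, 2, 3);

        // prints "4 8 2": tokens are now on dim 0, ready for ggml_ssm_conv
        printf("%lld %lld %lld\n", (long long) p->ne[0], (long long) p->ne[1], (long long) p->ne[2]);
        ggml_free(ctx);
    }
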
@@ -511,7 +513,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
-    conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
+    //conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
     conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
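The commented-out kernel permute makes sense given ggml_ssm_conv's contract: unlike the earlier depthwise-conv path, ggml_ssm_conv consumes the kernel as a plain 2-D [d_conv, d_inner] matrix, i.e. the layout ssm_conv1d is already stored in, and expects the input to carry the (d_conv - 1) cached columns in front of the new tokens. A hedged shape sketch with made-up sizes:

    // Shape contract of ggml_ssm_conv (toy sizes; illustrative, not the commit's code)
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = { 16u * 1024 * 1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        const int64_t d_conv = 4, d_inner = 8, n_tokens = 5, n_seqs = 2;
        // input: (d_conv - 1) cached columns, then the new tokens
        ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_tokens, d_inner, n_seqs);
        // kernel stays 2-D [d_conv, d_inner] -- no permute needed
        ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);

        ggml_tensor * y = ggml_ssm_conv(ctx, sx, c);
        // prints "8 5 2": [d_inner, n_tokens, n_seqs], one output column per token
        printf("%lld %lld %lld\n", (long long) y->ne[0], (long long) y->ne[1], (long long) y->ne[2]);
        ggml_free(ctx);
    }
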
@@ -522,29 +524,32 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Extract the last (conv_kernel_size - 1) states from conv_input
     ggml_tensor * last_conv_states =
         ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], conv_input->nb[2],
-            n_seq_tokens * (conv_input->nb[0]));
+            (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+        mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
 
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0, last_conv_states, ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
-            mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all))));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
     cb(conv_states_all, "conv_states_updated", il);
 
     // Apply convolution
-    ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
-    cb(conv_output, "conv_output_raw", il);
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); //ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
+    cb(conv_output_proper, "conv_output_raw", il);
 
     // Remove the padding
-    ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
-                                                        conv_output->nb[1], conv_output->nb[2], conv_output->nb[3], 
-                                                        (conv_kernel_size - 1) * ggml_element_size(conv_output));
-    cb(conv_output_no_padding, "conv_output_no_padding", il);
+    // ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
+    //                                                     conv_output->nb[1], conv_output->nb[2], conv_output->nb[3], 
+    //                                                     (conv_kernel_size - 1) * ggml_element_size(conv_output));
+    // cb(conv_output_no_padding, "conv_output_no_padding", il);
 
-    // Take only the first n_seq_tokens values
-    ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
-                                                        conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
-    cb(conv_output_proper, "conv_output_proper", il);
+    // // Take only the first n_seq_tokens values
+    // ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
+    //                                                     conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
+    // cb(conv_output_proper, "conv_output_proper", il);
 
-    conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
+    conv_output_proper = ggml_transpose(ctx0, conv_output_proper);
     conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
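
Two details of the hunk above that may not be obvious at first read. First, because the input already carries the cached (conv_kernel_size - 1) columns, ggml_ssm_conv's output is exactly [qkv_dim, n_seq_tokens, n_seqs] with no left padding, which is why the old padding-trimming views are now dead code. Second, the state update uses the usual llama.cpp recurrent-cache idiom: copy the trailing columns of the padded input into a 1-D view of the state buffer and force the copy into the graph with ggml_build_forward_expand, since no later node consumes its result. A self-contained sketch of that idiom, with toy sizes and names that are not the commit's:

    #include "ggml.h"
    #include "ggml-cpu.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = { 64u * 1024 * 1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        const int64_t d_conv = 4, channels = 2, n_tokens = 3, n_seqs = 1;
        // padded input: (d_conv - 1) cached columns followed by n_tokens new ones
        ggml_tensor * conv_input = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_tokens, channels, n_seqs);
        // flat recurrent-state buffer standing in for conv_states_all
        ggml_tensor * cache = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_conv - 1) * channels * n_seqs);
        for (int64_t i = 0; i < ggml_nelements(conv_input); ++i) {
            ((float *) conv_input->data)[i] = (float) i; // channel 0 holds 0..5, channel 1 holds 6..11
        }

        // view of the last (d_conv - 1) columns of every channel...
        ggml_tensor * last = ggml_view_3d(ctx, conv_input, d_conv - 1, channels, n_seqs,
            conv_input->nb[1], conv_input->nb[2],
            (conv_input->ne[0] - (d_conv - 1)) * ggml_element_size(conv_input));
        // ...copied into a 1-D window of the cache (offset 0 here; the commit offsets by mctx_cur->get_head())
        ggml_tensor * dst = ggml_view_1d(ctx, cache, (d_conv - 1) * channels * n_seqs, 0);

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, ggml_cpy(ctx, last, dst)); // keep the copy alive in the graph
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        for (int64_t i = 0; i < (d_conv - 1) * channels; ++i) {
            printf("%g ", ((float *) cache->data)[i]); // prints "3 4 5 9 10 11"
        }
        printf("\n");
        ggml_free(ctx);
    }
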