Просмотр исходного кода

Transpose input for convolution

Piotr Wilkin 3 месяцев назад
Родитель
Сommit
638057a29b
1 измененных файлов с 1 добавлено и 2 удалено
  1. 1 2
      src/models/llm_build_qwen3next.cpp

+ 1 - 2
src/models/llm_build_qwen3next.cpp

@@ -407,7 +407,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
 
     // Reshape to [n_tokens, qkv_dim, n_seqs] for proper convolution input format
-    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, n_tokens, qkv_dim, n_seqs);
+    qkv_mixed = ggml_cont_3d(ctx0, ggml_transpose(ctx0, qkv_mixed), n_tokens, qkv_dim, n_seqs);
     cb(qkv_mixed, "qkv_mixed_for_conv", il);
 
     // Calculate convolution kernel size
@@ -415,7 +415,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
-    // Now concatenate along the sequence dimension (dim 0 in Llama.cpp)
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);    
     cb(conv_input, "conv_input", il);