Proper multi-sequence convolution calculation, corrected (?) state management

Piotr Wilkin 3 months ago
Commit 4ef6f337de
1 changed file with 46 additions and 20 deletions

src/models/llm_build_qwen3next.cpp  +46 -20

@@ -11,22 +11,45 @@ static ggml_tensor* ggml_conv_1d_dw_f32(
         int           stride,
         int           padding,
         int           dilation) {
-    // Following the pattern from ggml_conv_1d_dw but using F32
-    // Reshape input from [length, channels, batch, dummy] to [length, 1, channels, batch]
-    ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1], input->ne[2]);
 
-    // Apply im2col with F32 destination type to avoid F16 requirement
-    ggml_tensor* im2col_result = ggml_im2col(ctx, kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
+    const int64_t n_seqs = input->ne[2];
+    const int64_t channels = input->ne[1];
 
-    // Now multiply: im2col_result * kernel (following the exact pattern from ggml_conv_1d_dw)
-    // In ggml_conv_1d_dw: ggml_mul_mat(ctx, im2col, a) where a is the kernel
-    ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, kernel);
-
-    // Reshape the result following ggml_conv_1d_dw: [result->ne[0], result->ne[2], 1]
-    ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
-    return output_3d;
+    if (n_seqs > 1) {
+        ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1] * input->ne[2], 1);
+        
+        // For the kernel in [kernel_size, 1, channel_size, 1] format:
+        // 1. First reshape to 2D: [kernel_size, channel_size]
+        ggml_tensor* kernel_2d = ggml_reshape_2d(ctx, kernel, kernel->ne[0], kernel->ne[2]);
+        
+        // 2. Repeat the kernel for each sequence: [kernel_size, channel_size * n_seqs]
+        ggml_tensor* repeated_kernel = ggml_repeat(ctx, kernel_2d, 
+            ggml_new_tensor_2d(ctx, kernel->type, kernel->ne[0], channels * n_seqs));
+        
+        // 3. Reshape back to 4D for im2col: [kernel_size, 1, channels * n_seqs, 1]
+        ggml_tensor* reshaped_kernel = ggml_reshape_4d(ctx, repeated_kernel, kernel->ne[0], 1, channels * n_seqs, 1);
+        
+        // Apply im2col
+        ggml_tensor* im2col_result = ggml_im2col(ctx, reshaped_kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
+        
+        // Multiply with the kernel
+        ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, reshaped_kernel);
+        
+        // Reshape result back to [output_len, channels, n_seqs]
+        ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], channels, n_seqs);
+        
+        return output_3d;
+    } else {
+        // Single sequence case - kernel is already in [kernel_size, 1, channel_size, 1]
+        ggml_tensor* reshaped_input = ggml_reshape_4d(ctx, input, input->ne[0], 1, input->ne[1], 1);
+        ggml_tensor* im2col_result = ggml_im2col(ctx, kernel, reshaped_input, stride, 0, padding, 0, dilation, 0, false, GGML_TYPE_F32);
+        ggml_tensor* mul_result = ggml_mul_mat(ctx, im2col_result, kernel);
+        ggml_tensor* output_3d = ggml_reshape_3d(ctx, mul_result, mul_result->ne[0], mul_result->ne[2], 1);
+        return output_3d;
+    }
 }
 
+
 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context_mamba(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
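
The heart of the first hunk is the new n_seqs > 1 branch: instead of running one depthwise convolution per sequence, the input is reshaped so the sequence dimension is folded into the channel dimension, and the per-channel kernel is tiled with ggml_repeat so every sequence sees the same weights. Below is a standalone sketch with plain arrays (not ggml, and with an illustrative causal left-padding rather than the exact padding used above) of why the folded single pass matches the per-sequence result:

// Standalone sketch: a depthwise 1-D conv over [length, channels, n_seqs]
// equals one over [length, channels * n_seqs] when the per-channel kernel
// is repeated once per sequence -- the trick the n_seqs > 1 branch uses.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Depthwise 1-D convolution with (ksize - 1) implicit zeros on the left.
// Layout: in[c * length + t], ker[c * ksize + k].
static std::vector<float> conv_dw_1d(const std::vector<float> & in,
                                     const std::vector<float> & ker,
                                     int length, int channels, int ksize) {
    std::vector<float> out(channels * length, 0.0f);
    for (int c = 0; c < channels; ++c) {
        for (int t = 0; t < length; ++t) {
            float acc = 0.0f;
            for (int k = 0; k < ksize; ++k) {
                const int src = t - (ksize - 1) + k;  // look back only
                if (src >= 0) {
                    acc += in[c * length + src] * ker[c * ksize + k];
                }
            }
            out[c * length + t] = acc;
        }
    }
    return out;
}

int main() {
    const int length = 4, channels = 2, n_seqs = 3, ksize = 2;

    // Per-channel kernel, and the same kernel tiled n_seqs times
    // (analogous to ggml_repeat on the reshaped 2-D kernel).
    std::vector<float> ker = { 0.5f, 1.0f,    // channel 0
                               2.0f, -1.0f }; // channel 1
    std::vector<float> ker_tiled;
    for (int s = 0; s < n_seqs; ++s) {
        ker_tiled.insert(ker_tiled.end(), ker.begin(), ker.end());
    }

    // One contiguous buffer holding all sequences: [length, channels, n_seqs].
    std::vector<float> in(length * channels * n_seqs);
    for (size_t i = 0; i < in.size(); ++i) {
        in[i] = 0.1f * (float) i;
    }

    // Route 1: convolve every sequence separately with the original kernel.
    std::vector<float> per_seq;
    for (int s = 0; s < n_seqs; ++s) {
        std::vector<float> slice(in.begin() + s * length * channels,
                                 in.begin() + (s + 1) * length * channels);
        std::vector<float> o = conv_dw_1d(slice, ker, length, channels, ksize);
        per_seq.insert(per_seq.end(), o.begin(), o.end());
    }

    // Route 2: a single pass over channels * n_seqs "channels" with the tiled kernel.
    std::vector<float> folded = conv_dw_1d(in, ker_tiled, length, channels * n_seqs, ksize);

    float max_diff = 0.0f;
    for (size_t i = 0; i < folded.size(); ++i) {
        max_diff = std::max(max_diff, std::fabs(folded[i] - per_seq[i]));
    }
    printf("max abs diff: %g\n", max_diff);  // prints 0: both routes agree
    return 0;
}
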
@@ -219,7 +242,7 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     GGML_ASSERT(k->ne[2] == n_tokens);
     GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
     GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == 1);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
 
     GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
@@ -505,9 +528,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
                                                         (conv_kernel_size - 1) * ggml_element_size(conv_output));
     cb(conv_output_no_padding, "conv_output_no_padding", il);
 
-    // Take only the first (n_tokens * n_seqs) values
-    ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_tokens * n_seqs, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
-                                                        conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
+    // Take only the last n_seq_tokens values
+    ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], 
+        conv_output_no_padding->ne[2], conv_output_no_padding->ne[3], conv_output_no_padding->nb[1], 
+        conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], (conv_output_no_padding->ne[0] - n_seq_tokens) * ggml_element_size(conv_output_no_padding));
     cb(conv_output_proper, "conv_output_proper", il);
 
     conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
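
The replacement view now starts (ne[0] - n_seq_tokens) elements into dimension 0 instead of at offset 0, which is how the last n_seq_tokens outputs of each row are selected without a copy; the original nb[1..3] strides are kept so the other dimensions are untouched. A tiny standalone illustration of that offset arithmetic (hypothetical sizes, plain pointers instead of ggml_view_4d):

// Keeping only the last n_seq_tokens entries of dimension 0 by shifting the
// start of the view while reusing the original row stride.
#include <cstdio>

int main() {
    const int ne0          = 7;  // values along dim 0 after the conv
    const int n_seq_tokens = 4;  // values we actually want to keep
    const int rows         = 3;  // stand-in for the remaining dimensions

    float buf[rows][ne0];
    for (int r = 0; r < rows; ++r)
        for (int i = 0; i < ne0; ++i)
            buf[r][i] = 10.0f * r + i;

    // The ggml call expresses this offset in bytes:
    // (ne0 - n_seq_tokens) * ggml_element_size(...).
    const int offset = ne0 - n_seq_tokens;

    for (int r = 0; r < rows; ++r) {
        const float * view = &buf[r][0] + offset;  // shifted start, same stride
        for (int i = 0; i < n_seq_tokens; ++i) {
            printf("%5.1f ", view[i]);             // last n_seq_tokens values of row r
        }
        printf("\n");
    }
    return 0;
}
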
@@ -556,7 +580,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
-    ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_dim, head_dim * n_heads, 1, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
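
Pulling the state through build_rs yields one recurrent state of hparams.n_embd_s() elements per sequence, so the follow-up reshape to [head_dim, head_dim * n_heads, 1, n_seqs] is only valid if those element counts line up. A toy consistency check with assumed, purely illustrative sizes:

// Illustrative numbers only: the reshape preserves the element count
// exactly when n_embd_s() equals head_dim * head_dim * n_heads.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t head_dim = 128;  // assumed value, for illustration
    const int64_t n_heads  = 16;   // assumed value, for illustration
    const int64_t n_seqs   = 4;

    const int64_t n_embd_s = head_dim * head_dim * n_heads;           // per-sequence state size
    const int64_t reshaped = head_dim * (head_dim * n_heads) * 1 * n_seqs;

    printf("n_embd_s * n_seqs = %lld, reshape total = %lld\n",
           (long long) (n_embd_s * n_seqs), (long long) reshaped);
    return 0;
}
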
@@ -594,8 +619,9 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * new_state = ggml_reshape_4d(ctx0, state_1d, head_dim, head_dim, n_heads, n_seqs);
     cb(new_state, "new_state", il);
 
-    // Update the recurrent states - we use the new_state directly since it's already the last state
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, ssm_states_all));
+    // Update the recurrent states
+    ggml_build_forward_expand(gf, ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
+        hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(mctx_cur->get_s_l(il))));
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
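
The write-back now addresses the per-layer state buffer through a 1-D view: n_seqs states of hparams.n_embd_s() elements each, starting at the slot index given by mctx_cur->get_head(). A small standalone sketch of that addressing (toy sizes; a plain memcpy stands in for the graph ops, only to show which region the view covers):

// The per-layer buffer holds one recurrent state of n_embd_s floats per slot;
// the current ubatch occupies n_seqs consecutive slots starting at "head".
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    const size_t n_embd_s = 8;   // illustrative per-sequence state size
    const size_t n_slots  = 5;   // illustrative number of slots in get_s_l(il)
    const size_t n_seqs   = 2;
    const size_t head     = 1;   // illustrative value of mctx_cur->get_head()

    std::vector<float> layer_states(n_slots * n_embd_s, 0.0f);  // the layer's state buffer
    std::vector<float> new_states(n_seqs * n_embd_s, 1.0f);     // freshly computed states

    // The 1-D view in the diff: n_embd_s * n_seqs elements,
    // starting n_embd_s * head elements into the buffer.
    const size_t offset = n_embd_s * head;
    std::memcpy(layer_states.data() + offset, new_states.data(),
                n_seqs * n_embd_s * sizeof(float));

    printf("slots written: [%zu, %zu)\n", head, head + n_seqs);
    return 0;
}
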
@@ -619,7 +645,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(cur, "linear_attn_out", il);
 
     // Reshape back to original dimensions
-    cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens));
+    cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs));
     return cur;
 }
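
The final hunk is element-count bookkeeping: the layer output holds n_seq_tokens * n_seqs token vectors, so reshaping to [n_embd, n_seq_tokens] only works for a single sequence. A short sketch with illustrative sizes:

// A reshape must preserve the element count; the old target drops n_seqs.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_embd       = 2048;  // illustrative
    const int64_t n_seq_tokens = 32;    // tokens per sequence
    const int64_t n_seqs       = 4;

    const int64_t have    = n_embd * n_seq_tokens * n_seqs;  // elements in the output
    const int64_t old_tgt = n_embd * n_seq_tokens;           // old reshape target
    const int64_t new_tgt = n_embd * n_seq_tokens * n_seqs;  // new reshape target

    printf("have %lld, old target %lld (mismatch when n_seqs > 1), new target %lld\n",
           (long long) have, (long long) old_tgt, (long long) new_tgt);
    return 0;
}
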