Piotr Wilkin преди 3 месеца
родител
ревизия
1579bcb202
променени са 2 файла, в които са добавени 19 реда и са изтрити 32 реда
  1. 18 31
      src/models/llm_build_qwen3next.cpp
  2. 1 1
      src/models/llm_build_qwen3next.h

+ 18 - 31
src/models/llm_build_qwen3next.cpp

@@ -98,7 +98,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         cb(attn_post_norm, "attn_post_norm", il);
 
         // FFN layer (MoE or dense) - without residual connection
-        cur = build_layer_ffn(attn_post_norm, model, il, false);
+        cur = build_layer_ffn(attn_post_norm, model, il);
         cb(cur, "ffn_out", il);
         
         // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
@@ -120,7 +120,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     cur = build_lora_mm(model.output, cur);
 
     cb(cur, "result_output", -1);
-    ggml_set_output(cur);
     res->t_logits = cur;
 
     ggml_build_forward_expand(gf, cur);
@@ -511,13 +510,25 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Calculate convolution kernel size
     ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
     conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
-    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);    
     cb(conv_input, "conv_input", il);
 
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], conv_input->nb[2],
+            n_seq_tokens * (conv_input->nb[0]));
+
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_states, ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+            mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all))));
+    cb(conv_states_all, "conv_states_updated", il);
+
     // Apply convolution
     ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
     cb(conv_output, "conv_output_raw", il);
@@ -539,19 +550,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
 
-    // Update convolution state cache
-    // Extract the last (conv_kernel_size - 1) states from conv_input
-    ggml_tensor * last_conv_states =
-        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1], conv_input->nb[2],
-                     n_seq_tokens * conv_input->nb[0]);
-
-    ggml_build_forward_expand(gf,
-                              ggml_cpy(ctx0, last_conv_states,
-                                       ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
-                                                    mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
-                                                        ggml_element_size(conv_states_all))));
-    cb(conv_states_all, "conv_states_updated", il);
-
     conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_seq_tokens * n_seqs, qkv_dim);
     cb(conv_output_proper, "conv_output_final", il);
 
@@ -615,12 +613,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
     cb(state_1d, "state_1d", il);
     
-    ggml_tensor * new_state = ggml_reshape_4d(ctx0, state_1d, head_dim, head_dim, n_heads, n_seqs);
-    cb(new_state, "new_state", il);
-
     // Update the recurrent states
-    ggml_build_forward_expand(gf, ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
-        hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(mctx_cur->get_s_l(il))));
+    ggml_build_forward_expand(gf, 
+        ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+            hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(ssm_states_all))));
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
@@ -648,7 +644,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     return cur;
 }
 
-ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual) {
+ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
 
     // Check if this is an MoE layer
     if (model.layers[il].ffn_gate_inp != nullptr) {
@@ -697,15 +693,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
                         model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
     }
-    // Residual connection (only if requested)
-    if (do_residual) {
-        cur = ggml_add(ctx0, cur, cur);
-        cb(cur, "ffn_residual", il);
-    }
-
-    cur = build_cvec(cur, il);
-    cb(cur, "l_out", il);
-
     return cur;
 };
 

+ 1 - 1
src/models/llm_build_qwen3next.h

@@ -36,7 +36,7 @@ private:
                                                     const llama_ubatch & ubatch,
                                                     int                  il);
 
-    ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual = true);
+    ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il);
 
     ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias);