Просмотр исходного кода

Correct convolution state dimension calculations

Piotr Wilkin 3 месяцев назад
Родитель
Сommit
27fa5f335d
2 измененных файлов с 13 добавлено и 17 удалено
  1. 1 1
      convert_hf_to_gguf.py
  2. 12 16
      src/llama-model.cpp

+ 1 - 1
convert_hf_to_gguf.py

@@ -3759,7 +3759,7 @@ class Qwen3NextModel(Qwen3MoeModel):
         self.gguf_writer.add_ssm_state_size(self.find_hparam(["linear_key_head_dim"]))
         self.gguf_writer.add_ssm_group_count(self.find_hparam(["linear_num_key_heads"]))
         self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["linear_num_value_heads"]))
-        self.gguf_writer.add_ssm_inner_size(self.find_hparam(["hidden_size"]) * (self.find_hparam(["linear_num_value_heads"]) // self.find_hparam(["linear_num_key_heads"])))
+        self.gguf_writer.add_ssm_inner_size(self.find_hparam(['linear_value_head_dim']) * self.find_hparam(['linear_num_value_heads']))
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if name.startswith("mtp"):

+ 12 - 16
src/llama-model.cpp

@@ -19240,7 +19240,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
         cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-        // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]
         int64_t       qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
         ggml_tensor * mixed_qkvz_reshaped =
             ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
@@ -19327,23 +19326,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Build the convolution states tensor
         ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
         cb(conv_states, "conv_states", il);
+        
+        // Combine query, key, value for convolution input
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
+        qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
+
+        int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
 
         // Calculate convolution kernel size
         const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
-
-        // Calculate input dimensions for Qwen3Next
-        const int64_t input_dim = (head_k_dim * num_k_heads * 2) + (head_v_dim * num_v_heads);
-
-        // Reshape conv_states to [conv_kernel_size - 1, input_dim, n_seqs]
-        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, input_dim, n_seqs);
+        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
         cb(conv_states, "conv_states_reshaped", il);
 
-        // Combine query, key, value for convolution input
-        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
-        qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
-
         // Reshape to [input_dim, n_seq_tokens, n_seqs] for concatenation
-        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, input_dim, n_seq_tokens, n_seqs);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_dim, n_seq_tokens, n_seqs);
         cb(qkv_mixed, "qkv_mixed_for_conv", il);
 
         // Concatenate cached conv states with current input
@@ -19367,18 +19363,18 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Update convolution state cache
         // Extract the last (conv_kernel_size - 1) states from conv_input
         ggml_tensor * last_conv_states =
-            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, input_dim, n_seqs, conv_input->nb[1],
+            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1],
                          conv_input->nb[2], n_seq_tokens * conv_input->nb[0]);
 
         ggml_build_forward_expand(
             gf, ggml_cpy(ctx0, last_conv_states,
-                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * input_dim * n_seqs,
-                                      mctx_cur->get_head() * (conv_kernel_size - 1) * input_dim *
+                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
+                                      mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
                                           ggml_element_size(conv_states_all))));
         cb(conv_states_all, "conv_states_updated", il);
 
         // Reshape conv_output back to proper dimensions
-        conv_output = ggml_reshape_4d(ctx0, conv_output, input_dim, n_seqs, n_seq_tokens, 1);
+        conv_output = ggml_reshape_4d(ctx0, conv_output, qkv_dim, n_seqs, n_seq_tokens, 1);
         cb(conv_output, "conv_output_reshaped", il);
         conv_output = ggml_permute(ctx0, conv_output, 0, 2, 1, 3);
         cb(conv_output, "conv_output_final", il);