Piotr Wilkin, 3 months ago
Parent
Commit 17240eafc0
1 changed file with 33 additions and 39 deletions

src/models/llm_build_qwen3next.cpp (+33, -39)

@@ -386,14 +386,13 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     const auto * mctx_cur = inp->mctx;
 
     const int64_t d_inner  = hparams.ssm_d_inner;
-    const int64_t n_heads  = hparams.ssm_dt_rank;
-    const int64_t head_dim = d_inner / n_heads;
+    
     const int64_t n_seqs   = ubatch.n_seqs;
 
     const int64_t head_k_dim  = hparams.ssm_d_state;
-    const int64_t head_v_dim  = hparams.ssm_d_state;
     const int64_t num_k_heads = hparams.ssm_n_group;
     const int64_t num_v_heads = hparams.ssm_dt_rank;
+    const int64_t head_v_dim  = d_inner / num_v_heads;
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
@@ -408,7 +407,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
     ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
@@ -441,15 +440,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Get convolution states from cache
-    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
-
-    // Build the convolution states tensor
-    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
-    cb(conv_states, "conv_states", il);
-
-        // Split mixed_qkvz into query, key, value, z
+    // Split mixed_qkvz into query, key, value, z
     int64_t split_sizes_qkvz[4] = {
         head_k_dim,                              // query size
         head_k_dim,                              // key size
@@ -457,47 +448,50 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         head_v_dim * num_v_heads / num_k_heads   // z size
     };
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, 
-                                            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
+    ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, 
+                                            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
     cb(query, "q", il);
 
-    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                        split_sizes_qkvz[0] * sizeof(float)));
+                        split_sizes_qkvz[0] * sizeof(float));
     cb(key, "k", il);
 
-    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
+                        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
     cb(value, "v", il);
 
-    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                        (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
+                        (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
     cb(z, "z", il);
 
-    // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
-    ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
-                    ggml_nelements(z_reshaped) ==
+    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) +
+                    ggml_nelements(z) ==
                 ggml_nelements(mixed_qkvz));
 
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(query_flat, "query_flat", il);
 
     // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(key_flat, "key_flat", il);
 
     // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(value_flat, "value_flat", il);
 
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
     ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
     qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
@@ -578,7 +572,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state = ggml_reshape_4d(ctx0, state, head_dim, head_dim * n_heads, 1, n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
@@ -598,17 +592,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_heads * n_seq_tokens * n_seqs;
+    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
     
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_seq_tokens, n_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
    
     // Extract the state part (second part of the concatenated tensor)
     // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
+    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
     
     ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
     cb(state_1d, "state_1d", il);
@@ -620,19 +614,19 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
     
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_seq_tokens, n_seqs);
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(final_output, "final_output", il);
 
     // Output projection
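
For reference, here is a minimal standalone sketch of the head-dimension bookkeeping this change switches to, using made-up placeholder hyperparameter values (the real ones come from hparams.ssm_d_inner, ssm_d_state, ssm_n_group and ssm_dt_rank):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Placeholder values for illustration only, not the model's actual configuration.
    const int64_t d_inner     = 4096;  // hparams.ssm_d_inner  (assumed)
    const int64_t head_k_dim  = 128;   // hparams.ssm_d_state  (assumed)
    const int64_t num_k_heads = 16;    // hparams.ssm_n_group  (assumed)
    const int64_t num_v_heads = 32;    // hparams.ssm_dt_rank  (assumed)

    // The value head size is now derived from the V-head count,
    // replacing the old shared head_dim = d_inner / n_heads.
    const int64_t head_v_dim = d_inner / num_v_heads;

    // Per-K-head width of the fused qkvz projection; the parentheses keep the
    // integer division num_v_heads / num_k_heads grouped, as in the new code.
    const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);

    // The four slices cut out of each per-K-head block must add back up to it.
    const int64_t split_sizes_qkvz[4] = {
        head_k_dim,                              // query
        head_k_dim,                              // key
        head_v_dim * num_v_heads / num_k_heads,  // value
        head_v_dim * num_v_heads / num_k_heads,  // z
    };
    int64_t total = 0;
    for (int64_t s : split_sizes_qkvz) {
        total += s;
    }
    assert(total == qkvz_new_dim);

    printf("head_v_dim=%lld qkvz_new_dim=%lld\n", (long long) head_v_dim, (long long) qkvz_new_dim);
    return 0;
}

The other half of the change keeps query, key, value and z as strided views of mixed_qkvz_reshaped and materializes each of them only once via ggml_cont_3d / ggml_cont_2d, rather than calling ggml_cont on every view and reshaping the result afterwards.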