Piotr Wilkin 3 месяцев назад
Родитель
Сommit
eb0a15fc9b
1 измененных файлов с 34 добавлено и 35 удалено
  1. 34 35
      src/models/llm_build_qwen3next.cpp

+ 34 - 35
src/models/llm_build_qwen3next.cpp

@@ -374,7 +374,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     const int64_t num_v_heads = hparams.ssm_dt_rank;
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-    const int64_t n_tokens     = ubatch.n_tokens;
 
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
@@ -388,11 +387,11 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
     int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
@@ -400,18 +399,18 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         num_v_heads / num_k_heads   // alpha size
     };
 
-    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
     cb(b, "b", il);
 
-    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 
                         split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
     // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
-    ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_tokens, n_seqs);
-    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
     GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
@@ -436,29 +435,29 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         head_v_dim * num_v_heads / num_k_heads   // z size
     };
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs, 
+    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, 
                                             mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
     cb(query, "q", il);
 
-    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
                         split_sizes_qkvz[0] * sizeof(float)));
     cb(key, "k", il);
 
-    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
                         (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
     cb(value, "v", il);
 
-    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
                         (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
     cb(z, "z", il);
 
     // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
     ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
                     ggml_nelements(z_reshaped) ==
@@ -466,15 +465,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(query_flat, "query_flat", il);
 
     // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(key_flat, "key_flat", il);
 
     // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(value_flat, "value_flat", il);
 
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
@@ -512,7 +511,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(conv_output_proper, "conv_output_proper", il);
 
     conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
-    conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_tokens, n_seqs);
+    conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
@@ -530,32 +529,32 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
                                                         ggml_element_size(conv_states_all))));
     cb(conv_states_all, "conv_states_updated", il);
 
-    conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_tokens * n_seqs, qkv_dim);
+    conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_seq_tokens * n_seqs, qkv_dim);
     cb(conv_output_proper, "conv_output_final", il);
 
     ggml_tensor * conv_transposed = ggml_transpose(ctx0, conv_output_proper);
     cb(conv_transposed, "conv_transposed", il);
 
-    ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_tokens * n_seqs);
+    ggml_tensor * conv_qkv_mix = ggml_cont_2d(ctx0, conv_transposed, qkv_dim, n_seq_tokens * n_seqs);
     cb(conv_qkv_mix, "conv_qkv_mix", il);
 
     // Extract the convolved Q, K, V from conv_output
-    ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs,
+    ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs,
         conv_qkv_mix->nb[1], 0);
     cb(q_conv, "q_conv", il);
-    ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_tokens * n_seqs, 
+    ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, 
         conv_qkv_mix->nb[1], head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
-    ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_tokens * n_seqs, 
+    ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, 
         conv_qkv_mix->nb[1], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
     
     // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
+    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
-    beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tokens, n_seqs);
+    beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
 
@@ -564,8 +563,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
-        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
-        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
     }
 
     cb(q_conv, "q_conv_predelta", il);
@@ -577,12 +576,12 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_heads * n_tokens * n_seqs;
+    const int64_t output_flat_size = head_dim * n_heads * n_seq_tokens * n_seqs;
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
     
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_seq_tokens, n_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
    
     // Extract the state part (second part of the concatenated tensor)
@@ -600,19 +599,19 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
     
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_seq_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_seq_tokens, n_seqs);
     cb(final_output, "final_output", il);
 
     // Output projection
@@ -620,7 +619,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(cur, "linear_attn_out", il);
 
     // Reshape back to original dimensions
-    cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
+    cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens));
     return cur;
 }