Piotr Wilkin 3 months ago
Parent
Commit
890fa2c1e3
4 changed files with 115 additions and 89 deletions
  1. ggml/include/ggml-delta.h (+0, -2)
  2. ggml/src/ggml-delta.c (+6, -52)
  3. src/llama-context.cpp (+1, -1)
  4. src/llama-model.cpp (+108, -34)

+ 0 - 2
ggml/include/ggml-delta.h

@@ -23,8 +23,6 @@ GGML_API struct ggml_tensor * ggml_delta_net(struct ggml_context * ctx,
                                              struct ggml_tensor *  v,
                                              struct ggml_tensor *  q,
                                              struct ggml_tensor *  g,
-                                             struct ggml_tensor *  conv_weight,
-                                             struct ggml_tensor *  conv_bias,
                                              struct ggml_tensor *  beta,
                                              struct ggml_tensor *  state,
                                              bool                  use_qk_l2norm,
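
With the conv_weight/conv_bias parameters removed, the op expects q/k/v that have already been convolved by the caller. A minimal sketch of a call against the new signature (argument names are illustrative, mirroring the call site in src/llama-model.cpp further down):

    // the caller now owns the short-convolution step; ggml_delta_net only
    // runs the gated delta rule on pre-convolved inputs
    struct ggml_tensor * out = ggml_delta_net(ctx,
            k_conv,             // k (already convolved)
            v_conv,             // v (already convolved)
            q_conv,             // q (already convolved)
            g,                  // per-head gate
            beta,               // beta (sigmoid is applied inside)
            state,              // recurrent state
            /*use_qk_l2norm=*/ true,
            /*scale=*/         1.0f);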

+ 6 - 52
ggml/src/ggml-delta.c

@@ -17,8 +17,6 @@ struct ggml_tensor * ggml_delta_net(
         struct ggml_tensor  * v,
         struct ggml_tensor  * q,
         struct ggml_tensor  * g,
-        struct ggml_tensor  * conv_weight,
-        struct ggml_tensor  * conv_bias,
         struct ggml_tensor  * beta,
         struct ggml_tensor  * state,
         bool                  use_qk_l2norm,
@@ -55,67 +53,23 @@ struct ggml_tensor * ggml_delta_net(
        
     GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
        
-    // Merge q, k, v into qkv
-    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
-    report_tensor_size("mixed_qkv_qk", mixed_qkv);
-    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
-    report_tensor_size("mixed_qkv_qkv", mixed_qkv);
-
-    uint32_t dim = (S_v * H_v) + 2 * (H_k * S_k);
-
-    mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, n_seqs, dim, n_tokens);
-    report_tensor_size("mixed_qkv_reshaped", mixed_qkv);
-    struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
-    report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);
-
-    // Apply convolution
-    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
-    report_tensor_size("conv_out", conv_out);
-
-    if (conv_bias) {
-        conv_out = ggml_add(ctx, conv_out, conv_bias);
-        report_tensor_size("conv_out_bias", conv_out);
-    }
-
-    conv_out = ggml_silu(ctx, conv_out);
-    report_tensor_size("conv_out_silu", conv_out);
-
-    conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_seqs, n_tokens, 1);
-    report_tensor_size("conv_out_reshaped", conv_out);
-
-    conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
-    report_tensor_size("conv_out_transposed", conv_out);
-
     // Beta sigmoid
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
     report_tensor_size("beta_sigmoid", beta_sigmoid);
 
     // Gate calculations are done elsewhere in llama-model.cpp
 
-    // Re-split the qkv tensors
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2], 
-                                               H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], 0);
-    report_tensor_size("q_conv_view", q_conv);
-
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
-                                               H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], S_k * H_k * sizeof(q->type));
-    report_tensor_size("k_conv_view", k_conv);
-
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, S_v, H_v, conv_out->ne[1], conv_out->ne[2], H_v * sizeof(float),
-                                               conv_out->nb[1], conv_out->nb[2], (2 * S_k * H_k) * sizeof(q->type));
-    report_tensor_size("v_conv_view", v_conv);
-
-    struct ggml_tensor * q_broadcast = q_conv;
-    struct ggml_tensor * k_broadcast = k_conv;
+    struct ggml_tensor * q_broadcast = q;
+    struct ggml_tensor * k_broadcast = k;
     
     // if the number of key heads and value heads differ, repeat to force tensors into matching shapes
     if (H_k != H_v) {
         GGML_ASSERT(H_v % H_k == 0);
         int64_t repeat_factor = H_v / H_k;
         
-        q_broadcast = ggml_cont_4d(ctx, q_conv, S_k, n_tokens, H_k, n_seqs);
+        q_broadcast = ggml_cont_4d(ctx, q, S_k, n_tokens, H_k, n_seqs);
         report_tensor_size("q_broadcast_reshape1", q_broadcast);
-        k_broadcast = ggml_cont_4d(ctx, k_conv, S_k, n_tokens, H_k, n_seqs);
+        k_broadcast = ggml_cont_4d(ctx, k, S_k, n_tokens, H_k, n_seqs);
         report_tensor_size("k_broadcast_reshape1", k_broadcast);
         
         q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
@@ -129,7 +83,7 @@ struct ggml_tensor * ggml_delta_net(
         report_tensor_size("k_broadcast_reshape2", k_broadcast);
     }
 
-    struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v_conv, S_v, H_v, n_seqs, n_tokens);
+    struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v, S_v, H_v, n_seqs, n_tokens);
     report_tensor_size("v_reshape", v_reshape);
     struct ggml_tensor * g_reshape = ggml_cont_4d(ctx, g, S_v, H_v, n_seqs, n_tokens);
     report_tensor_size("g_reshape", g_reshape);
@@ -211,7 +165,7 @@ struct ggml_tensor * ggml_delta_net_op(
     struct ggml_tensor * kv_mem_presum_squeeze = ggml_reshape_4d(ctx, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
     report_tensor_size("kv_mem_presum_sequeeze", kv_mem_presum_squeeze);
 
-    struct ggml_tensor * kv_mem = ggml_permute(ctx, ggml_sum_rows(ctx, ggml_permute(ctx, kv_mem_presum_squeeze, 1, 2, 0, 3)), 2, 0, 1, 3);
+    struct ggml_tensor * kv_mem = ggml_permute(ctx, ggml_sum_rows(ctx, ggml_cont(ctx, ggml_permute(ctx, kv_mem_presum_squeeze, 1, 2, 0, 3))), 2, 0, 1, 3);
     report_tensor_size("kv_mem", kv_mem);
 
     struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx, kv_mem, S_v, S_v, n_seq, n_tokens);
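
The added ggml_cont is the functional fix in this hunk: ggml_permute only rewrites strides without moving data, and ggml_sum_rows reduces along dim 0 assuming a contiguous layout. A minimal sketch of the pattern (tensor names illustrative):

    // permute returns a strided view; no data is moved
    struct ggml_tensor * perm = ggml_permute(ctx, t, 1, 2, 0, 3);
    // materialize the permuted layout before reducing over dim 0;
    // the result has shape [1, ne1, ne2, ne3]
    struct ggml_tensor * sum = ggml_sum_rows(ctx, ggml_cont(ctx, perm));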

+ 1 - 1
src/llama-context.cpp

@@ -1362,7 +1362,7 @@ void llama_context::output_reorder() {
 //
 
 uint32_t llama_context::graph_max_nodes() const {
-    return std::max<uint32_t>(1024u, 128u*model.n_tensors());
+    return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
 
 llm_graph_result * llama_context::get_gf_res_reserve() const {
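
The reserve budget drops from 128 to 8 graph nodes per model tensor, with the 1024 floor still covering tiny models. Illustrative arithmetic (the tensor count is an assumption, not a measured value):

    // e.g. a model with ~4000 tensors:
    //   before: max(1024, 128 * 4000) = 512000 nodes reserved
    //   after : max(1024,   8 * 4000) =  32000 nodes reserved
    uint32_t nodes = std::max<uint32_t>(1024u, 8u * n_tensors);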

+ 108 - 34
src/llama-model.cpp

@@ -19136,11 +19136,12 @@ private:
             head_v_dim * num_v_heads / num_k_heads   // z size
         };
 
-        ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens,
-                                           n_seqs, split_sizes_qkvz[0] * sizeof(float), mixed_qkvz_reshaped->nb[1],
-                                           mixed_qkvz_reshaped->nb[2], 0));
+        ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads,
+                                                           n_tokens, n_seqs, split_sizes_qkvz[0] * sizeof(float),
+                                                           mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], 0));
 
-        ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+        ggml_tensor * key =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
                                          split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
                                          mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
 
@@ -19155,8 +19156,9 @@ private:
                          (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
 
         // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
-        ggml_tensor * value_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
-        ggml_tensor * z_reshaped     = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * value_reshaped =
+            ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
 
         GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
                         ggml_nelements(z_reshaped) ==
@@ -19183,38 +19185,106 @@ private:
         GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
         ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
-        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);        // A_log.exp()
-        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);     // A_log.exp() * softplus
-        ggml_tensor * gate           = ggml_scale(ctx0, gate_scaled, -1.0f);          // - (A_log.exp() * softplus)
+        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);     // A_log.exp()
+        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);  // A_log.exp() * softplus
+        ggml_tensor * gate           = ggml_scale(ctx0, gate_scaled, -1.0f);       // - (A_log.exp() * softplus)
 
-        // Get convolution weights and bias
-        ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
-        ggml_tensor * conv_bias   = nullptr;  // Add if your model has conv bias
+        // Get convolution states from cache
+        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+        ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+        // Build the convolution states tensor
+        ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+
+        // Calculate convolution kernel size
+        const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
+
+        // Calculate input dimensions for Qwen3Next
+        const int64_t input_dim = (head_k_dim * num_k_heads * 2) + (head_v_dim * num_v_heads);
+
+        // Reshape conv_states to [conv_kernel_size - 1, input_dim, n_seqs]
+        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, input_dim, n_seqs);
+        cb(conv_states, "conv_states_reshaped", il);
+
+        // Combine query, key, value for convolution input
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
+        qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
 
-        // Get recurrent states (conv_states not needed as it's handled internally by ggml_delta_net)
-        ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+        // Reshape to [input_dim, n_seq_tokens, n_seqs] for concatenation
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, input_dim, n_seq_tokens, n_seqs);
+        cb(qkv_mixed, "qkv_mixed_for_conv", il);
+
+        // Concatenate cached conv states with current input
+        // conv_states: [conv_kernel_size - 1, input_dim, n_seqs]
+        // qkv_mixed: [input_dim, n_seq_tokens, n_seqs]
+        // After transpose: [n_seq_tokens, input_dim, n_seqs]
+        ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, ggml_transpose(ctx0, qkv_mixed), 0);
+        cb(conv_input, "conv_input", il);
+
+        // Apply convolution
+        ggml_tensor * conv_output = ggml_ssm_conv(ctx0, conv_input, model.layers[il].ssm_conv1d);
+        cb(conv_output, "conv_output_raw", il);
+
+        if (model.layers[il].ssm_conv1d_b) {
+            conv_output = ggml_add(ctx0, conv_output, model.layers[il].ssm_conv1d_b);
+            cb(conv_output, "conv_output_bias", il);
+        }
+        conv_output = ggml_silu(ctx0, conv_output);
+        cb(conv_output, "conv_output_silu", il);
+
+        // Update convolution state cache
+        // Extract the last (conv_kernel_size - 1) states from conv_input
+        ggml_tensor * last_conv_states =
+            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, input_dim, n_seqs, conv_input->nb[1],
+                         conv_input->nb[2], n_seq_tokens * conv_input->nb[0]);
+
+        ggml_build_forward_expand(
+            gf, ggml_cpy(ctx0, last_conv_states,
+                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * input_dim * n_seqs,
+                                      mctx_cur->get_head() * (conv_kernel_size - 1) * input_dim *
+                                          ggml_element_size(conv_states_all))));
+        cb(conv_states_all, "conv_states_updated", il);
+
+        // Reshape conv_output back to proper dimensions
+        conv_output = ggml_reshape_4d(ctx0, conv_output, input_dim, n_seqs, n_seq_tokens, 1);
+        cb(conv_output, "conv_output_reshaped", il);
+        conv_output = ggml_permute(ctx0, conv_output, 0, 2, 1, 3);
+        cb(conv_output, "conv_output_final", il);
+
+        // Extract the convolved Q, K, V from conv_output
+        ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
+                                            head_k_dim * ggml_element_size(conv_output), conv_output->nb[1], conv_output->nb[2], 0));
+        cb(q_conv, "q_conv", il);
+        ggml_tensor * k_conv =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
+                         head_k_dim * ggml_element_size(conv_output), conv_output->nb[1], conv_output->nb[2],
+                         head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+        cb(k_conv, "k_conv", il);
+        ggml_tensor * v_conv =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs,
+                         head_v_dim * ggml_element_size(conv_output), conv_output->nb[1], conv_output->nb[2],
+                         2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+        cb(v_conv, "v_conv", il);
+
+        ggml_build_forward_expand(gf, ssm_states_all);
 
         // Beta tensor
         beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
 
-        ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
+        ggml_tensor * state           = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
         ggml_tensor * state_broadcast = ggml_repeat_4d(ctx0, state, head_dim, head_dim * n_heads, n_seqs, n_tokens);
-        ggml_tensor * target_gate    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
-        ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
-        gate                         = ggml_repeat(ctx0, gate_broadcast, target_gate);
+        ggml_tensor * target_gate     = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
+        ggml_tensor * gate_broadcast  = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
+        gate                          = ggml_repeat(ctx0, gate_broadcast, target_gate);
 
         // Call the new ggml_delta_net function with the corrected flow
         ggml_tensor * output = ggml_delta_net(ctx0,
-                                              key,             // k tensor
-                                              value_reshaped,  // v tensor
-                                              query,           // q tensor
-                                              gate,            // g tensor
-                                              conv_weight,     // conv_weight tensor
-                                              conv_bias,       // conv_bias tensor (can be nullptr)
-                                              beta,            // beta tensor
-                                              state_broadcast, // state tensor
-                                              true,            // use_qk_l2norm
-                                              1.0f             // scale (adjust based on your model)
+                                              k_conv,           // k tensor (already convolved)
+                                              v_conv,           // v tensor (already convolved)
+                                              q_conv,           // q tensor (already convolved)
+                                              gate,             // g tensor
+                                              beta,             // beta tensor
+                                              state_broadcast,  // state tensor
+                                              true,             // use_qk_l2norm
+                                              1.0f              // scale
         );
         cb(output, "delta_net_output", il);
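
The cache plumbing added above follows the usual ggml recurrent-conv pattern: prepend the cached last (kernel - 1) positions to the new tokens, convolve, then store the trailing (kernel - 1) positions back for the next ubatch. A condensed sketch under illustrative names:

    // cached: [d_conv - 1, dim, n_seqs]; x_t: [n_tok, dim, n_seqs]
    ggml_tensor * conv_in  = ggml_concat(ctx0, cached, x_t, 0);     // [d_conv - 1 + n_tok, dim, n_seqs]
    ggml_tensor * conv_out = ggml_ssm_conv(ctx0, conv_in, conv_w);  // causal conv over dim 0
    // roll the cache: keep the last d_conv - 1 positions of conv_in
    ggml_tensor * next_cache = ggml_view_3d(ctx0, conv_in, d_conv - 1, dim, n_seqs,
                                            conv_in->nb[1], conv_in->nb[2],
                                            n_tok * conv_in->nb[0]);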
 
@@ -19223,20 +19293,24 @@ private:
                                               output->nb[1], output->nb[2], 0);
 
         // Extract the new state
-        ggml_tensor * new_state = ggml_view_4d(ctx0, output, head_dim, head_dim * n_heads, n_tokens, n_seqs, 
-            output->nb[0], output->nb[1], output->nb[2], n_tokens * n_seqs * head_dim * n_heads * ggml_element_size(output));
+        ggml_tensor * new_state =
+            ggml_view_4d(ctx0, output, head_dim, head_dim * n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1],
+                         output->nb[2], n_tokens * n_seqs * head_dim * n_heads * ggml_element_size(output));
 
         // Only return the last recurrent state
-        struct ggml_tensor * state_reshaped = ggml_cont_4d(ctx0, new_state, head_dim, head_dim, n_heads, n_tokens * n_seqs);
-        struct ggml_tensor * state_last = ggml_view_4d(ctx0, state_reshaped, head_dim, head_dim, n_heads, 1, 
-            state_reshaped->nb[1], state_reshaped->nb[2], state_reshaped->nb[3], head_dim * head_dim * n_heads * ((n_seqs * n_tokens) - 1));
+        struct ggml_tensor * state_reshaped =
+            ggml_cont_4d(ctx0, new_state, head_dim, head_dim, n_heads, n_tokens * n_seqs);
+        struct ggml_tensor * state_last = ggml_view_4d(
+            ctx0, state_reshaped, head_dim, head_dim, n_heads, 1, state_reshaped->nb[1], state_reshaped->nb[2],
+            state_reshaped->nb[3], head_dim * head_dim * n_heads * ((n_seqs * n_tokens) - 1) * ggml_element_size(state_reshaped));
 
         // Update the recurrent states
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_last, ssm_states_all));
 
         // Reshape both attn_out and z to 2D tensors for normalization
         // attn_out: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-        ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim, n_heads * n_tokens * n_seqs);
+        ggml_tensor * attn_out_2d =
+            ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim, n_heads * n_tokens * n_seqs);
 
         // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
         ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
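
For reference, the last-state selection above boils down to viewing the final [head_dim, head_dim, n_heads] slab of the contiguous per-step state stream; ggml_view_4d takes a byte offset, hence the ggml_element_size factor. A sketch (names illustrative, F32 states assumed):

    // states: [head_dim, head_dim, n_heads, n_tokens * n_seqs], contiguous
    size_t off = (size_t) head_dim * head_dim * n_heads
               * (n_tokens * n_seqs - 1) * ggml_element_size(states);
    struct ggml_tensor * last = ggml_view_4d(ctx0, states, head_dim, head_dim, n_heads, 1,
                                             states->nb[1], states->nb[2], states->nb[3], off);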