Explorar el Código

All's well that ends in a well

Piotr Wilkin hace 3 meses
padre
commit
5306640300
Se han modificado 3 ficheros con 174 adiciones y 8 borrados
  1. 156 4
      src/models/llm_build_qwen3next.cpp
  2. 13 0
      src/models/llm_build_qwen3next.h
  3. 5 4
      tools/main/main.cpp

+ 156 - 4
src/models/llm_build_qwen3next.cpp

@@ -382,6 +382,144 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     return result;
 }
 
+// delta_net_recurrent
+// Recurrent version of delta_net for sequence_length = 1
+struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 eps_norm,
+        const int             il
+    ) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k = q->ne[0];
+    const int64_t H_k = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && q->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
+
+    cb(q, "q_prenorm", il);
+    cb(k, "k_prenorm", il);
+
+    if (use_qk_l2norm) {
+        q = ggml_l2_norm(ctx, q, eps_norm);
+        k = ggml_l2_norm(ctx, k, eps_norm);
+    }
+
+    cb(k, "k_postnorm", il);
+    cb(q, "q_prescale", il);
+
+    float scale = 1.0f / sqrtf(S_v);
+    q = ggml_scale(ctx, q, scale);
+
+    cb(beta, "beta_raw", il);
+    beta = ggml_sigmoid(ctx, beta);
+
+    cb(q, "q_postscale", il);
+    cb(beta, "beta_sigmoid", il);
+
+    // Reshape tensors for recurrent computation
+    // From [S_k, H_k, n_tokens, n_seqs] to [S_k, n_tokens, H_k, n_seqs]
+    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
+    cb(q, "q_reshape", il);
+    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
+    cb(k, "k_reshape", il);
+    v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
+    cb(v, "v_reshape", il);
+    
+    beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
+    cb(beta, "beta_reshape", il);
+
+    g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
+    cb(g, "g_permute", il);
+
+    ggml_tensor * q_t = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
+    ggml_tensor * k_t = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
+    ggml_tensor * v_t = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
+    ggml_tensor * g_t = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
+    ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
+    state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
+
+    // Apply exponential to gate: exp(g)
+    ggml_tensor * g_exp = ggml_exp(ctx, g_t);
+    cb(g_exp, "g_exp", il);
+
+    // Apply gate to state: state = state * exp(g)
+    ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
+    cb(gated_state, "gated_state", il);
+
+    // Compute kv_memory from state and key
+    // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
+    
+    // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
+    // to make it compatible with k_expanded for element-wise multiplication
+    ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
+    cb(gated_state_reshaped, "gated_state_reshaped", il);
+    
+    ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
+    cb(state_k_product, "state_k_product", il);
+
+    ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
+    cb(kv_memory, "kv_memory", il);
+
+    // Compute delta = (v - kv_memory) * beta
+    ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
+    ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
+    cb(delta, "delta", il);
+
+    // Update state = state + k * delta
+    // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+    ggml_tensor * delta_t = ggml_transpose(ctx, delta);
+
+    // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
+    ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
+    ggml_tensor * k_t_broadcast  = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
+    ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
+    cb(k_delta_product, "k_delta", il);
+
+    ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
+    cb(updated_state, "updated_state", il);
+    
+    ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
+    cb(state_q_product, "state_q_product", il);
+    ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
+    cb(output, "output", il);
+
+    // Concatenate output and updated_state into a single tensor
+    // First, flatten both tensors to 1D
+    ggml_tensor * output_1d = ggml_cont_1d(ctx, output, ggml_nelements(output));
+    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
+    
+    // Concatenate them: [output, updated_state]
+    ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);
+    return result;
+}
+
 
 ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
                                                                      ggml_tensor *        cur,
@@ -402,6 +540,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    const auto kv_head = mctx_cur->get_head();
+
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -494,6 +634,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
+    bool is_generation = mctx_cur->get_rs_z() < 0;
+
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -528,7 +670,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(last_conv_states, "last_conv_states", il);
 
     ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
-        mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+        kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
     cb(state_update_target, "state_update_target", il);
 
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
@@ -584,6 +726,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
+    cb(state, "state_predelta", il);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
@@ -598,8 +741,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Call the new delta_net function with the corrected flow
-    ggml_tensor * attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    // Choose between delta_net and delta_net_recurrent based on generation mode
+    ggml_tensor * attn_out;
+    if (is_generation) {
+        // Use delta_net_recurrent for single token generation
+        attn_out = delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    } else {
+        // Use regular delta_net for prompt processing
+        attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    }
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
@@ -621,7 +771,9 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Update the recurrent states
     ggml_build_forward_expand(gf, 
         ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-            hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(ssm_states_all))));
+            kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]

+ 13 - 0
src/models/llm_build_qwen3next.h

@@ -23,6 +23,19 @@ private:
         float                 eps_norm,
         const int             il);
 
+    // delta_net_recurrent
+    struct ggml_tensor * delta_net_recurrent(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 eps_norm,
+        const int             il);
+
     ggml_tensor * build_qwen3next_attention_layer(ggml_tensor *             cur,
                                                   ggml_tensor *             inp_pos,
                                                   llm_graph_input_attn_kv * inp_attn,

+ 5 - 4
tools/main/main.cpp

@@ -242,7 +242,8 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
     if (!ggml_is_quantized(t->type)) {
         uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
         ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
-        if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-") {
+        if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-" || 
+            std::string(t->name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
             if (cb_data->tensors.count(t->name) == 0) {
                 cb_data->tensors[t->name] = 1;
             } else {
@@ -311,9 +312,9 @@ int main(int argc, char ** argv) {
     std::vector<common_chat_msg> chat_msgs;
 
     // load the model and apply lora adapter, if any
-    callback_data cb_data;
-    params.cb_eval = ggml_debug;
-    params.cb_eval_user_data = &cb_data;
+    // callback_data cb_data;
+    // params.cb_eval = ggml_debug;
+    // params.cb_eval_user_data = &cb_data;
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
     common_init_result llama_init = common_init_from_params(params);