
Optimization: Qwen3 next autoregressive pass (#17996)

* It's Qwen3 Next, the lean mean token generation machine!

* Apply patches from thread

* Remove recurrent version, only keep chunked and autoregressive

* Remove unnecessary conts and asserts

* Remove more extra conts and asserts

* Cleanup masking
Piotr Wilkin (ilintar) 1 month ago
parent
commit
a5251ca11d
2 changed files with 86 additions and 271 deletions
  1. +5 −5    src/models/models.h
  2. +81 −266 src/models/qwen3next.cpp
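
A note before the diff: the new single-token path, build_delta_net_autoregressive, needs no causal/identity/diag masks, so they are dropped from its signature, while the chunked path now builds them once at CHUNK_SIZE instead of ubatch.n_seq_tokens. Reading the comments in the new function, the per-head update it performs for one decode step appears to be the following; the notation here is mine, not part of the commit (S is the per-head recurrent state of shape S_v x S_v, k_t/v_t/q_t the current projections after the L2 norm, scaling and sigmoid already applied in the graph, g_t the log-decay, beta_t the gate):

    S      <- exp(g_t) * S
    kv_mem = S^T k_t
    delta  = beta_t * (v_t - kv_mem)
    S      <- S + k_t delta^T
    o_t    = S^T q_t
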

+ 5 - 5
src/models/models.h

@@ -441,13 +441,14 @@ private:
                 ggml_tensor * cur,
                 ggml_tensor * causal_mask,
                 ggml_tensor * identity,
+                ggml_tensor * diag_mask,
                         int   il);
 
     ggml_tensor * build_layer_ffn(
                 ggml_tensor * cur,
                         int   il);
 
-    ggml_tensor * build_delta_net_recurrent(
+    ggml_tensor * build_delta_net_chunking(
                 ggml_tensor * q,
                 ggml_tensor * k,
                 ggml_tensor * v,
@@ -456,18 +457,17 @@ private:
                 ggml_tensor * state,
                 ggml_tensor * causal_mask,
                 ggml_tensor * identity,
+                ggml_tensor * diag_mask,
                         int   il);
 
-    ggml_tensor * build_delta_net_chunking(
+    ggml_tensor * build_delta_net_autoregressive(
                 ggml_tensor * q,
                 ggml_tensor * k,
                 ggml_tensor * v,
                 ggml_tensor * g,
                 ggml_tensor * beta,
                 ggml_tensor * state,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                        int   il);
+                int           il);
 
     ggml_tensor * build_norm_gated(
                 ggml_tensor * input,

+ 81 - 266
src/models/qwen3next.cpp

@@ -17,13 +17,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
                     GGML_TRI_TYPE_LOWER);
 
-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
+    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
 
     ggml_build_forward_expand(gf, causal_mask);
     ggml_build_forward_expand(gf, identity);
+    ggml_build_forward_expand(gf, diag_mask);
 
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
@@ -34,7 +36,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
             // Linear attention layer (gated delta net)
-            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il);
+            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
         } else {
             // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -93,14 +95,8 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
         ggml_tensor * state,
         ggml_tensor * causal_mask,
         ggml_tensor * identity,
+        ggml_tensor * diag_mask,
         int           il) {
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
     const int64_t S_k      = q->ne[0];
     const int64_t H_k      = q->ne[1];
     const int64_t n_tokens = q->ne[2];
@@ -120,15 +116,10 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
 
     GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
 
-    // TODO: can this ever be false?
-    const bool use_qk_l2norm = true;
-
-    if (use_qk_l2norm) {
-        const float eps_norm = hparams.f_norm_rms_eps;
+    const float eps_norm = hparams.f_norm_rms_eps;
 
-        q = ggml_l2_norm(ctx0, q, eps_norm);
-        k = ggml_l2_norm(ctx0, k, eps_norm);
-    }
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
 
     const float scale = 1.0f / sqrtf(S_v);
 
@@ -136,8 +127,6 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
 
     beta = ggml_sigmoid(ctx0, beta);
 
-    ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
-
     cb(q, "q_in", il);
     cb(k, "k_in", il);
     cb(v, "v_in", il);
@@ -188,36 +177,21 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     cb(v_beta, "v_beta", il);
     cb(k_beta, "k_beta", il);
 
-    ggml_tensor * chunked_mask =
-        ggml_view_4d(ctx0, causal_mask, chunk_size,
-                chunk_size,         causal_mask->ne[2], causal_mask->ne[3],
-                causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0);
+    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
+    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
 
-    ggml_tensor * chunked_diag_mask =
-        ggml_view_4d(ctx0, causal_diag_mask, chunk_size,
-                chunk_size,              causal_diag_mask->ne[2], causal_diag_mask->ne[3],
-                causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0);
-
-    ggml_tensor * chunked_identity =
-        ggml_view_4d(ctx0, identity, chunk_size,
-            chunk_size,      identity->ne[2], identity->ne[3],
-            identity->nb[1], identity->nb[2], identity->nb[3], 0);
-
-    q      = ggml_cont_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_cont_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_cont_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g    = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
+    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
+    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
 
     ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
 
     cb(g_cumsum, "g_cumsum", il);
 
-    ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
 
     ggml_tensor * gcs_j_broadcast =
         ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
@@ -226,23 +200,23 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
 
     cb(decay_mask, "decay_mask", il);
 
-    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
     decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
 
     ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
 
     ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask));
+    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
 
     cb(attn, "attn_pre_solve", il);
 
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower);
+    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
+    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
 
     ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, chunked_mask);
-    attn                     = ggml_add(ctx0, attn, chunked_identity);
+    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
+    attn                     = ggml_add(ctx0, attn, identity);
 
     cb(attn, "attn_solved", il);
 
@@ -291,7 +265,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
         // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
         attn = ggml_mul(ctx0, attn, decay_mask_chunk);
-        attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask));
+        attn = ggml_mul(ctx0, attn, diag_mask);
 
         ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
 
@@ -361,23 +335,14 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     return ggml_concat(ctx0, flat_output, flat_state, 0);
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
+ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * g,
         ggml_tensor * beta,
         ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
         int           il) {
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
     const int64_t S_k      = q->ne[0];
     const int64_t H_k      = q->ne[1];
     const int64_t n_tokens = q->ne[2];
@@ -386,6 +351,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];
 
+    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
     GGML_ASSERT(v->ne[2] == n_tokens);
     GGML_ASSERT(k->ne[2] == n_tokens);
     GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -397,215 +363,65 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
 
     GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
 
-    // TODO: can this ever be false?
-    const bool use_qk_l2norm = true;
-
-    if (use_qk_l2norm) {
-        const float eps_norm = hparams.f_norm_rms_eps;
+    const float eps_norm = hparams.f_norm_rms_eps;
 
-        q = ggml_l2_norm(ctx0, q, eps_norm);
-        k = ggml_l2_norm(ctx0, k, eps_norm);
-    }
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
 
     const float scale = 1.0f / sqrtf(S_v);
 
-    q = ggml_scale(ctx0, q, scale);
-
+    q    = ggml_scale(ctx0, q, scale);
     beta = ggml_sigmoid(ctx0, beta);
 
-    ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
-
     cb(q, "q_in", il);
     cb(k, "k_in", il);
     cb(v, "v_in", il);
     cb(beta, "beta_in", il);
     cb(g, "g_in", il);
 
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
     state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
 
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-
-    cb(k_beta, "k_beta", il);
-    cb(v_beta, "v_beta", il);
-    cb(g_cumsum, "g_cumsum", il);
-
-    ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs);  // [chunk_size, 1, n_tokens, n_seqs]
-    ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs);  // [1, chunk_size, n_tokens, n_seqs]
-
-    // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
-    // ggml_tensor * gcs_i_broadcast =
-    //     ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
-    //                     n_seqs);  // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-    // Don't need this, this one will get auto-broadcast
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs);  // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-
-    // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
-    decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
-    // Apply exponential to get the decay mask values
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    // Apply lower triangular mask again to ensure only lower triangular values remain
-    decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
-
-    cb(decay_mask, "decay_mask", il);
-
-    // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    cb(kmulkbeta, "kmulkbeta", il);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-
-    cb(attn, "attn_pre_rec", il);
-
-    // for i in range(1, chunk_size):
-    //          row = attn[..., i, :i].clone()
-    //          sub = attn[..., :i, :i].clone()
-    //          attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
-    //
-    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-
-    // value = attn @ v_beta
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    cb(v, "value_beta", il);
-
-    // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
-
-    cb(gexp, "g_cum_exp", il);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-
-    cb(kbeta_gexp, "kbeta_gexp", il);
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-
-    cb(k_cumdecay, "k_cumdecay", il);
-
-    // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-    attn = ggml_mul_mat(ctx0, k, q);
-    attn = ggml_mul(ctx0, attn, decay_mask);
-    attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask));
-
-    cb(attn, "attn_decay_key", il);
-
-    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
-
-    // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-    ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay);
-
-    cb(v_prime, "v_prime", il);
-
-    // v_new = v_i - v_prime
-    ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime);
-
-    ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-
-    cb(v_new, "v_new", il);
-
-    // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-    ggml_tensor * q_g_exp    = ggml_mul(ctx0, q, gexp);
-    ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-
-    cb(attn_inter, "attn_inter", il);
-
-    // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-    ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
-
-    cb(v_attn, "v_attn", il);
-
-    ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn);
-
-    cb(core_attn_out, "core_attn_out", il);
-
-    // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    // key_gdiff = key * g_diff.unsqueeze(-1)
-    // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    ggml_tensor * g_cum_last =
-        ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3],
-                                    g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3],
-                                    g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));
-
-    cb(g_cum_last, "g_cum_last", il);
-
-    ggml_tensor * gexp_last =
-        ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
-
-    cb(gexp_last, "gexp_last", il);
-
-    ggml_tensor * g_cum_last_3d =
-        ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
-
-    cb(g_cum_last_3d, "g_cum_last_3d", il);
-
-    ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
-
-    cb(g_cumsum_3d, "g_cumsum_3d", il);
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
-
-    cb(g_diff, "g_diff", il);
-
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-
-    cb(g_diff_exp, "g_diff_exp", il);
-
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k,
-                                    ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
-                                                    g_diff_exp->ne[2] * g_diff_exp->ne[3]));
-
-    cb(key_gdiff, "key_gdiff", il);
-
-    ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
-
-    cb(kgdmulvnew, "kgdmulvnew", il);
-
-    state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew);
-
+    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
+    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to g_t
+    g_t = ggml_exp(ctx0, g_t);
+
+    // Apply the gated delta rule for the single timestep
+    // last_recurrent_state = last_recurrent_state * g_t
+    state = ggml_mul(ctx0, state, g_t);
+
+    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
+    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
+    // we need to sum over dim=-2, so we transpose, sum, then transpose again
+    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
+
+    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
+    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    // delta = (v_t - kv_mem) * beta_t
+    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
+    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
+
+    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
+    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
+    state                   = ggml_add(ctx0, state, k_t_delta);
+
+    // Compute the attention output
+    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
+    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
+    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
+    // again, since it's over dim = -2, transpose, sum, transpose back
+    ggml_tensor * core_attn_out =
+        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
+
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);
 
-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
+    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
+    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
+    ggml_tensor * flat_state  = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
 
     return ggml_concat(ctx0, flat_output, flat_state, 0);
 }
@@ -712,6 +528,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         ggml_tensor *        cur,
         ggml_tensor *        causal_mask,
         ggml_tensor *        identity,
+        ggml_tensor *        diag_mask,
         int                  il) {
     const auto * mctx_cur = inp->mctx;
 
@@ -737,11 +554,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
     int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
-    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
@@ -762,8 +579,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
-    GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
-
     ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
@@ -799,9 +614,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
                                    (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
     cb(z, "z", il);
 
-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) ==
-                ggml_nelements(mixed_qkvz));
-
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
     ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
@@ -925,10 +737,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
-    ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
-        build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) :
-        build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
+    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
+    ggml_tensor * attn_out;
+    if (n_seq_tokens == 1) {
+        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
+    } else {
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+    }
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
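
For readers who want the dataflow of the new autoregressive path without the ggml broadcasting tricks (the transpose / ggml_sum_rows / transpose pattern used for the dim=-2 reductions), here is a minimal scalar sketch of the same single-token update for one head. It is illustrative only and not part of the commit; the function name, the row-major layout and the assumption that k/q are already L2-normalized, q scaled and beta passed through the sigmoid (as done earlier in the graph) are mine.

    #include <cmath>
    #include <vector>

    // One head, one timestep of the gated delta rule.
    // state is d_k x d_v, stored row-major: state[i*d_v + j].
    static void delta_rule_step(std::vector<float> & state,
                                const float * k, const float * v, const float * q,
                                float g, float beta,
                                int d_k, int d_v, float * out /* size d_v */) {
        const float decay = std::exp(g);

        // state = state * g_t.exp()
        for (float & s : state) {
            s *= decay;
        }

        // kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        std::vector<float> kv_mem(d_v, 0.0f);
        for (int i = 0; i < d_k; ++i) {
            for (int j = 0; j < d_v; ++j) {
                kv_mem[j] += state[i*d_v + j] * k[i];
            }
        }

        // delta = (v_t - kv_mem) * beta_t;  state += k_t.unsqueeze(-1) * delta
        for (int i = 0; i < d_k; ++i) {
            for (int j = 0; j < d_v; ++j) {
                state[i*d_v + j] += k[i] * (v[j] - kv_mem[j]) * beta;
            }
        }

        // core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
        for (int j = 0; j < d_v; ++j) {
            out[j] = 0.0f;
        }
        for (int i = 0; i < d_k; ++i) {
            for (int j = 0; j < d_v; ++j) {
                out[j] += state[i*d_v + j] * q[i];
            }
        }
    }

Compared with the removed build_delta_net_recurrent, this per-token formulation avoids building the decay mask, the triangular solve and the chunk bookkeeping for a single token, which is what most of the 81-added / 266-removed line delta above reflects.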