Piotr Wilkin 4 months ago
Parent
Commit
178230ee21
2 changed files with 235 additions and 120 deletions
  1. ggml/src/ggml.c (+104 −48)
  2. src/llama-model.cpp (+131 −72)

ggml/src/ggml.c (+104 −48)

@@ -3435,7 +3435,7 @@ struct ggml_tensor * ggml_reshape_4d(
         int64_t               ne2,
         int64_t               ne3) {
     GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+     GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
 
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
@@ -5441,17 +5441,25 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));
     
-    const int64_t S = k->ne[0];
-    const int64_t H = k->ne[1];
+    const int64_t S_k = k->ne[0];
+    const int64_t H_k = k->ne[1];
     const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     
-    // Validate dimensions
-    GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
-    GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+    
+    // Validate dimensions - allow different head dimensions for q/k vs v
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(q->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[2] == n_tokens);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && (beta->ne[2] == n_seqs || beta->ne[2] == 1));
+    GGML_ASSERT(ggml_nelements(state) == S_v * H_v * n_seqs);
+    
+    // Check that q and k have the same dimensions
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens);
     
     // Apply L2 normalization to query and key if requested
     struct ggml_tensor * q_norm = q;
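For reference, the optional q/k normalization that use_qk_l2norm toggles here scales each head vector to unit L2 length over its S_k elements. A minimal scalar sketch of that math, assuming an in-place helper and an illustrative eps (neither taken from this commit):

    #include <math.h>
    #include <stdint.h>

    // Scalar sketch of per-head L2 normalization over the head dimension S_k.
    // The eps guard is an assumed value, not taken from the commit.
    static void l2_norm_head(float * x, int64_t S_k, float eps) {
        float ss = 0.0f;
        for (int64_t i = 0; i < S_k; ++i) ss += x[i] * x[i];
        const float inv = 1.0f / sqrtf(ss + eps);
        for (int64_t i = 0; i < S_k; ++i) x[i] *= inv;
    }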
@@ -5466,53 +5474,101 @@ struct ggml_tensor * ggml_delta_net(
     
     // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-    
-    // Apply causal 1D convolution preprocessing to mixed QKV
-    // Concatenate q, k, v along the feature dimension
-    int64_t concat_ne[4] = { q->ne[0], q->ne[1], q->ne[2], q->ne[3] * 3 };
-    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 3);
-    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 3);
-    
-    // Transpose for convolution: [S, H, n_tokens, n_seqs*3] -> [S, n_tokens, H, n_seqs*3]
-    mixed_qkv = ggml_permute(ctx, mixed_qkv, 0, 2, 1, 3);
-    
-    // Apply causal 1D convolution
-    struct ggml_tensor * conv_out = ggml_conv_1d(
-        ctx,
-        conv_weight,
-        mixed_qkv,
-        1,  // stride
-        conv_weight->ne[2] - 1,  // padding (kernel_size - 1)
-        1   // dilation
-    );
-    
+    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
+    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
+
+    u_int32_t dim = (S_v * H_v) + 2 * (H_k * S_k);
+
+    mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
+    struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
+
+    // Apply SSM convolution
+    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
+
     // Apply bias if provided
     if (conv_bias) {
         conv_out = ggml_add(ctx, conv_out, conv_bias);
     }
-    
+
     // Apply SiLU activation
     conv_out = ggml_silu(ctx, conv_out);
-    
-    // Transpose back: [S, n_tokens, H, n_seqs*3] -> [S, H, n_tokens, n_seqs*3]
+
+    // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
+    conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, 1, 1);
+
+    // Transpose to get the right layout: [dim, n_tokens, 1] -> [dim, 1, n_tokens, 1]
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
+
+    // q projection view
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
+                                               S_k,                  // ne0
+                                               H_k,                  // ne1
+                                               conv_out->ne[1],      // ne2 = sequence length (1)
+                                               conv_out->ne[2],      // ne3 = batch (1)
+                                               H_k * sizeof(float),  // nb1 = stride along H_k
+                                               conv_out->nb[1],      // nb2 = stride along sequence dim
+                                               conv_out->nb[2],      // nb3 = stride along batch dim
+                                               0                     // offset in bytes
+    );
+
+    // k projection view
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
+                                               S_k,                       // ne0
+                                               H_k,                       // ne1
+                                               conv_out->ne[1],           // ne2
+                                               conv_out->ne[2],           // ne3
+                                               H_k * sizeof(float),       // nb1
+                                               conv_out->nb[1],           // nb2
+                                               conv_out->nb[2],           // nb3
+                                               S_k * H_k * sizeof(q->type)  // offset = skip q_out
+    );
+
+    // v projection view
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
+                                               S_v,                             // ne0
+                                               H_v,                             // ne1
+                                               conv_out->ne[1],                 // ne2
+                                               conv_out->ne[2],                 // ne3
+                                               H_v * sizeof(float),             // nb1
+                                               conv_out->nb[1],                 // nb2
+                                               conv_out->nb[2],                 // nb3
+                                               (2 * S_k * H_k) * sizeof(q->type)// offset = skip q_out + k_out
+    );
+
+    // Transpose each component back to original layout: [S_v, 1, token_split_size, 1] -> [S_v, token_split_size, 1, 1]
+    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
+    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
+    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
+
+    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens);
+    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, 1, n_tokens);
+    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, 1, n_tokens);
     
-    // Split the convolved output back into q, k, v components
-    // Split along the last dimension (3 * original size)
-    int64_t split_size = q->ne[3];
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, q->ne[0], q->ne[1], q->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2], 0);
-    
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, k->ne[0], k->ne[1], k->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-                                               split_size * ggml_type_size(q->type));
+    // NOW we repeat query and key to match value head dimensions if needed (after convolution)
+    struct ggml_tensor * q_broadcast = q_conv;
+    struct ggml_tensor * k_broadcast = k_conv;
     
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, v->ne[0], v->ne[1], v->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-                                               2 * split_size * ggml_type_size(q->type));
+    if (H_k != H_v) {
+        // Calculate the repeat factor: H_v / H_k
+        GGML_ASSERT(H_v % H_k == 0);
+        int64_t repeat_factor = H_v / H_k;
+        
+        // Repeat query and key along the head dimension
+        // First reshape to separate the repeat dimension: [S_k, H_k, n_tokens, 1] -> [S_k, 1, H_k, n_tokens]
+        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, 1, H_k, n_tokens);
+        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, 1, H_k, n_tokens);
+        
+        // Repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
+        q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens);
+        k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens);
+        
+        // Reshape back to original dimensions but with H_v heads: [S_k, H_v, n_tokens, 1]
+        q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, 1);
+        k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, 1);
+    }
     
     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    const int64_t ne[4] = { S_v * H_v, n_tokens + H_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
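The ggml_view_4d calls above carve the convolved fused tensor back into q/k/v by byte offset. The same splitting idea in 2D form may be easier to read; this sketch assumes F32 storage and a contiguous fused tensor of shape [dim, n_tokens], with fused, dim and the S/H variables standing in for the ones used in this hunk:

    // Sketch: split a contiguous F32 tensor fused[dim, n_tokens] into q/k/v
    // segments with ggml_view_2d. Offsets and the row stride nb1 are in bytes,
    // and every view keeps the full fused row stride so token rows stay aligned.
    const size_t ts  = ggml_type_size(GGML_TYPE_F32);
    const size_t nb1 = (size_t) dim * ts;  // byte stride between token rows of fused
    struct ggml_tensor * q_seg = ggml_view_2d(ctx, fused, S_k*H_k, n_tokens, nb1, 0);
    struct ggml_tensor * k_seg = ggml_view_2d(ctx, fused, S_k*H_k, n_tokens, nb1, (size_t)(  S_k*H_k) * ts);
    struct ggml_tensor * v_seg = ggml_view_2d(ctx, fused, S_v*H_v, n_tokens, nb1, (size_t)(2*S_k*H_k) * ts);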
     
     // Set operation parameters for the delta rule computation
@@ -5520,15 +5576,15 @@ struct ggml_tensor * ggml_delta_net(
         chunk_size,
         use_qk_l2norm ? 1 : 0,
         0, 0,  // reserved
-        0, 0, 0, 0  // scale and other params
+        0, 0, 0  // scale and other params
     };
     memcpy(params + 4, &scale, sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
     
     // Use custom operation for the gated delta rule computation
     result->op = GGML_OP_DELTA_NET;
-    result->src[0] = q_conv;
-    result->src[1] = k_conv;
+    result->src[0] = q_broadcast;
+    result->src[1] = k_broadcast;
     result->src[2] = v_conv;
     result->src[3] = g;
     result->src[4] = beta_sigmoid;
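The build side packs the op parameters as int32 slots, with the float scale bit-copied into slot 4 by memcpy. A compute kernel consuming GGML_OP_DELTA_NET would read them back the same way; the sketch below only mirrors the layout written above and is not taken from an existing kernel (dst names the op's result tensor by convention):

    // Sketch: reading the params back on the compute side, mirroring the packing
    // above (slot 0 = chunk_size, slot 1 = use_qk_l2norm, slots 2-3 reserved,
    // slot 4 = the float scale stored as raw bits). dst is the op's result tensor.
    const int32_t * p = dst->op_params;
    const int32_t chunk_size    = p[0];
    const int32_t use_qk_l2norm = p[1];
    float scale;
    memcpy(&scale, p + 4, sizeof(float));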

src/llama-model.cpp (+131 −72)

@@ -19049,9 +19049,9 @@ private:
         cb(Kcur, "Kcur", il);
         cb(Vcur, "Vcur", il);
 
-        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
 
         // Apply Q/K normalization
         Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
@@ -19079,8 +19079,8 @@ private:
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
         
         // Apply gating
-        gate = ggml_reshape_2d(ctx0, gate, n_embd_q, n_tokens);
-        cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+        gate = ggml_reshape_2d(ctx0, ggml_cont(ctx0, gate), n_embd_q, n_tokens);
+        cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);
         
         return cur;
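The ggml_cont wrappers added in this hunk exist because ggml_reshape_* asserts a contiguous input (see the ggml_reshape_4d hunk in ggml.c above), and tensors produced by ggml_permute or strided views are not contiguous. A minimal sketch of the pattern, with x and the target shape as placeholders:

    // Sketch: a permuted tensor is only a strided view, so it has to be
    // materialized with ggml_cont before ggml_reshape_* will accept it.
    struct ggml_tensor * t = ggml_permute(ctx0, x, 0, 2, 1, 3);  // non-contiguous view
    t = ggml_cont(ctx0, t);                                      // copy into contiguous layout
    t = ggml_reshape_3d(ctx0, t, ne0, ne1, ne2);                 // now passes the contiguity assert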
@@ -19096,59 +19096,102 @@ private:
         const auto   kv_head  = mctx_cur->get_head();
 
         const int64_t d_inner  = hparams.ssm_d_inner;
-        const int64_t d_state  = hparams.ssm_d_state;
         const int64_t n_heads  = hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_heads;
         const int64_t n_seqs   = ubatch.n_seqs;
 
+        const int64_t head_k_dim  = hparams.ssm_d_state;
+        const int64_t head_v_dim  = hparams.ssm_d_state;
+        const int64_t num_k_heads = hparams.ssm_n_group;
+        const int64_t num_v_heads = hparams.ssm_dt_rank;
+
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens     = ubatch.n_tokens;
 
         GGML_ASSERT(n_seqs != 0);
         GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        // Input projection for QKV and beta/alpha
-        ggml_tensor * qkvz_ba = build_lora_mm(model.layers[il].ssm_in, cur);
-        cb(qkvz_ba, "linear_attn_in_proj", il);
-
-        // Split into QKV and beta/alpha components
-        const int64_t qkv_size = d_inner * 2 + d_state * 2;
-
-        ggml_tensor * qkv =
-            ggml_view_3d(ctx0, qkvz_ba, qkv_size, n_tokens, 1, qkv_size * sizeof(float), qkvz_ba->nb[1], 0);
-        ggml_tensor * ba = ggml_view_2d(ctx0, qkvz_ba, n_embd, n_tokens, 
-                               qkvz_ba->nb[1], qkv_size * sizeof(float));
-
-        // Reshape QKV for processing
-        qkv = ggml_reshape_3d(ctx0, qkv, head_dim, n_heads * 2 + d_state * 2 / head_dim, n_tokens);
-
-        // Split into individual components
-        ggml_tensor * query =
-            ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1], 0);
-        ggml_tensor * key   = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
-                                           n_heads * head_dim * sizeof(float));
-        ggml_tensor * value = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
-                                           n_heads * head_dim * 2 * sizeof(float));
-
-        // Process beta and alpha parameters (corrected dimensions)
-        ggml_tensor * beta_alpha = build_lora_mm(model.layers[il].ssm_beta_alpha, ba);
-        ggml_tensor * beta =
-            ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float), beta_alpha->nb[1], 0);
-        ggml_tensor * alpha = ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float),
-                                           beta_alpha->nb[1], n_heads * sizeof(float));
-
-        // Apply sigmoid to beta (exactly like reference: beta = b.sigmoid())
-        beta = ggml_sigmoid(ctx0, beta);
-
-        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);        // a + dt_bias
-        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased);             // exp(a + dt_bias)
-        ggml_tensor * one_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);  // Create scalar tensor
-        one_tensor = ggml_exp(ctx0, one_tensor); // e^0 = 1
-        ggml_tensor * one_plus_exp = ggml_add1(ctx0, alpha_exp, one_tensor);     // 1 + exp(a + dt_bias)
-        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);       // log(1 + exp(...))
-        ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a);                    // A_log.exp()
-        ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
-        ggml_tensor * gate = ggml_neg(ctx0, gate_scaled);                   // - (A_log.exp() * softplus)
+        // Input projections
+        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
+        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+        ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
+        cb(mixed_ba, "linear_attn_mixed_ba", il);
+
+        // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]
+        int64_t       qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+        ggml_tensor * mixed_qkvz_reshaped =
+            ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
+
+        // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
+        int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
+        ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
+
+        // Split mixed_qkvz into query, key, value, z
+        int64_t split_sizes_qkvz[4] = {
+            head_k_dim,                              // query size
+            head_k_dim,                              // key size
+            head_v_dim * num_v_heads / num_k_heads,  // value size
+            head_v_dim * num_v_heads / num_k_heads   // z size
+        };
+
+        ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens,
+                                           n_seqs, split_sizes_qkvz[0] * sizeof(float), mixed_qkvz_reshaped->nb[1],
+                                           mixed_qkvz_reshaped->nb[2], 0));
+
+        ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+                                         split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
+                                         mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
+
+        ggml_tensor * value =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+                         split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                         (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+
+        ggml_tensor * z =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+                         split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                         (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
+
+        // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
+        ggml_tensor * value_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * z_reshaped     = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+
+        GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
+                        ggml_nelements(z_reshaped) ==
+                    ggml_nelements(mixed_qkvz));
+
+        // Split mixed_ba into b and a (beta and alpha parameters)
+        int64_t split_sizes_ba[2] = {
+            num_v_heads / num_k_heads,  // beta size
+            num_v_heads / num_k_heads   // alpha size
+        };
+
+        ggml_tensor * b =
+            ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
+                         split_sizes_ba[0] * sizeof(float), mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], 0);
+
+        ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
+                                       split_sizes_ba[1] * sizeof(float), mixed_ba_reshaped->nb[1],
+                                       mixed_ba_reshaped->nb[2], split_sizes_ba[0] * sizeof(float));
+
+        // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
+        ggml_tensor * beta  = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * alpha = ggml_reshape_3d(ctx0, ggml_cont(ctx0, a), num_v_heads, n_tokens, n_seqs);
+
+        GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
+
+        // Softplus would be nice...
+        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);  // a + dt_bias
+        ggml_tensor * alpha_exp    = ggml_exp(ctx0, alpha_biased);                    // exp(a + dt_bias)
+        ggml_tensor * one_tensor   = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);      // Create scalar tensor
+        ggml_exp(ctx0, one_tensor);                                                   // make it a 1
+        ggml_tensor * one_plus_exp   = ggml_add1(ctx0, alpha_exp, one_tensor);        // 1 + exp(a + dt_bias)
+        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                  // log(1 + exp(...))
+        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);        // A_log.exp()
+        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);     // A_log.exp() * softplus
+        ggml_tensor * gate           = ggml_neg(ctx0, gate_scaled);                   // - (A_log.exp() * softplus)
 
         // Get convolution weights and bias
         ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
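The gate block above assembles softplus from exp, add1 and log because no dedicated softplus op is used here, computing gate = -exp(A_log) * softplus(a + dt_bias) per head. A scalar reference of that math; the overflow guard is an added assumption for numerical safety, not part of the commit:

    #include <math.h>

    // Scalar reference for the gate computed above:
    //   softplus(x) = log(1 + exp(x))
    //   gate        = -exp(A_log) * softplus(a + dt_bias)
    static float softplus_ref(float x) {
        return (x > 20.0f) ? x : log1pf(expf(x));  // guard keeps expf() from overflowing
    }

    static float delta_gate_ref(float a, float dt_bias, float A_log) {
        return -expf(A_log) * softplus_ref(a + dt_bias);
    }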
@@ -19157,12 +19200,6 @@ private:
         // Get recurrent states (conv_states not needed as it's handled internally by ggml_delta_net)
         ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
 
-        // Reshape tensors to match ggml_delta_net expectations
-        // [S, H, n_tokens, n_seqs] format
-        query = ggml_reshape_4d(ctx0, query, head_dim, n_heads, n_tokens, n_seqs);
-        key   = ggml_reshape_4d(ctx0, key, head_dim, n_heads, n_tokens, n_seqs);
-        value = ggml_reshape_4d(ctx0, value, head_dim, n_heads, n_tokens, n_seqs);
-
         // Beta tensor
         beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
 
@@ -19170,22 +19207,25 @@ private:
         ggml_tensor * state = ggml_view_4d(ctx0, ssm_states_all, head_dim, head_dim, n_heads, n_seqs,
                                            ssm_states_all->nb[0], ssm_states_all->nb[1], ssm_states_all->nb[2],
                                            kv_head * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all));
-        state = ggml_cont(ctx0, state);
-        gate = ggml_repeat(ctx0, gate, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, n_heads, n_tokens, n_seqs));
+        state               = ggml_cont(ctx0, state);
+
+        ggml_tensor * target_gate    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
+        ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
+        gate                         = ggml_repeat(ctx0, gate_broadcast, target_gate);
 
-        // Call the new ggml_delta_net function
+        // Call the new ggml_delta_net function with the corrected flow
         ggml_tensor * output = ggml_delta_net(ctx0,
-                                              key,          // k tensor
-                                              value,        // v tensor
-                                              query,        // q tensor
-                                              gate,         // g tensor
-                                              conv_weight,  // conv_weight tensor
-                                              conv_bias,    // conv_bias tensor (can be nullptr)
-                                              beta,         // beta tensor
-                                              state,        // state tensor
-                                              64,           // chunk_size (adjust as needed)
-                                              true,         // use_qk_l2norm
-                                              1.0f          // scale (adjust based on your model)
+                                              key,             // k tensor
+                                              value_reshaped,  // v tensor
+                                              query,           // q tensor
+                                              gate,            // g tensor
+                                              conv_weight,     // conv_weight tensor
+                                              conv_bias,       // conv_bias tensor (can be nullptr)
+                                              beta,            // beta tensor
+                                              state,           // state tensor
+                                              64,              // chunk_size (adjust as needed)
+                                              true,            // use_qk_l2norm
+                                              1.0f             // scale (adjust based on your model)
         );
         cb(output, "delta_net_output", il);
 
@@ -19205,18 +19245,37 @@ private:
                              ctx0, ssm_states_all, head_dim * head_dim * n_heads * n_seqs,
                              kv_head * n_seqs * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all))));
 
-        // Apply normalization and gating
-        attn_out = build_norm(attn_out, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        // Reshape both attn_out and z to 2D tensors for normalization
+        // attn_out: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+        ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim, n_heads * n_tokens * n_seqs);
+
+        // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+        ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
+
+        // Apply gated normalization: self.norm(core_attn_out, z)
+        // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
+        ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+
+        // Apply silu gate: attn_out_norm * silu(z_2d)
+        ggml_tensor * z_silu       = ggml_silu(ctx0, z_2d);
+        ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
+
+        // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
+        ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
+
+        // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
+        ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
 
         // Output projection
-        cur = build_lora_mm(model.layers[il].wo, attn_out);
+        cur = build_lora_mm(model.layers[il].ssm_out, final_output);
         cb(cur, "linear_attn_out", il);
 
         // Reshape back to original dimensions
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
 
         return cur;
     }
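The gated normalization above follows the RMSNorm(x) * silu(z) form noted in the comments (Qwen3NextRMSNormGated-style), applied per head over head_dim. A scalar sketch under that reading; the eps parameter and helper name are assumptions:

    #include <math.h>
    #include <stdint.h>

    // Scalar sketch of gated RMS normalization over one head vector of length d:
    //   y[i] = (x[i] / rms(x)) * w[i] * silu(z[i]),  silu(z) = z * sigmoid(z)
    static void rms_norm_gated_ref(const float * x, const float * z, const float * w,
                                   float * y, int64_t d, float eps) {
        float ss = 0.0f;
        for (int64_t i = 0; i < d; ++i) ss += x[i] * x[i];
        const float inv_rms = 1.0f / sqrtf(ss / (float) d + eps);
        for (int64_t i = 0; i < d; ++i) {
            const float silu_z = z[i] / (1.0f + expf(-z[i]));
            y[i] = x[i] * inv_rms * w[i] * silu_z;
        }
    }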
+
     ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
         // Check if this is an MoE layer
         if (model.layers[il].ffn_gate_inp != nullptr) {