View source code

Added: tri, cumsum. Still a mess.

Piotr Wilkin 3 months ago
Parent
Commit
7ec2df64a4

+ 39 - 0
ggml/include/ggml.h

@@ -243,6 +243,8 @@
 
 #define GGML_MROPE_SECTIONS   4
 
+#define GGML_DELTA_NET_CHUNK    64
+
 #define GGML_UNUSED(x) (void)(x)
 #ifdef __CUDACC__
 template<typename... Args>
@@ -472,6 +474,7 @@ extern "C" {
         GGML_OP_COS,
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
+        GGML_OP_CUMSUM,
         GGML_OP_MEAN,
         GGML_OP_ARGMAX,
         GGML_OP_COUNT_EQUAL,
@@ -527,6 +530,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
+        GGML_OP_TRI,
 
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
@@ -539,6 +543,7 @@ extern "C" {
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
+        GGML_OP_DELTA_NET,
 
         GGML_OP_UNARY,
 
@@ -612,6 +617,13 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    enum ggml_tri_type {
+        GGML_TRI_TYPE_UPPER_DIAG        = 0,
+        GGML_TRI_TYPE_UPPER             = 1,
+        GGML_TRI_TYPE_LOWER_DIAG        = 2,
+        GGML_TRI_TYPE_LOWER             = 3
+    };
+
     struct ggml_init_params {
         // memory pool
         size_t mem_size;   // bytes
@@ -975,6 +987,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_cumsum(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
     // mean along rows
     GGML_API struct ggml_tensor * ggml_mean(
             struct ggml_context * ctx,
@@ -2119,6 +2135,17 @@ extern "C" {
             int                   shift2,
             int                   shift3);
 
+    // Make a matrix triangular (upper, upper + diagonal, lower, or lower + diagonal):
+    // the selected triangle is filled with the constant value, everything else is set to zero
+    GGML_API struct ggml_tensor * ggml_tri(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 constant,
+            enum ggml_tri_type    tritype);
+
+    // same as ggml_tri, but the selected triangle keeps the original values of a
+    GGML_API struct ggml_tensor * ggml_tri_keep(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_tri_type    tritype);
 
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
@@ -2289,6 +2316,18 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            bool                  use_qk_l2norm,
+            float                 scale,
+            float                 eps_norm);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
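The declarations above only add the operator entry points; a minimal usage sketch (not part of the commit, assuming the usual ggml context setup and the CPU backend's ggml_graph_compute_with_ctx helper) could look like this:

// Minimal sketch: build a 4x4 matrix, take its cumulative sum along dim 0
// (within each row), and keep only the lower triangle (diagonal included).
// Assumes a plain ggml context with no_alloc = false and the CPU backend
// from this commit.
#include "ggml.h"
#include "ggml-cpu.h"

void example_tri_cumsum(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
    // ... fill a->data with 4*4 floats ...

    // cumulative sum along dim 0 (within each row)
    struct ggml_tensor * cs   = ggml_cumsum(ctx, a);
    // zero everything above the diagonal, keep original values elsewhere
    struct ggml_tensor * low  = ggml_tri_keep(ctx, cs, GGML_TRI_TYPE_LOWER_DIAG);
    // same shape, but the kept triangle is filled with a constant instead
    struct ggml_tensor * ones = ggml_tri(ctx, a, 1.0f, GGML_TRI_TYPE_UPPER);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, low);
    ggml_build_forward_expand(gf, ones);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    ggml_free(ctx);
}

ggml_tri fills the selected triangle with the constant and zeroes the rest; ggml_tri_keep preserves the original values inside that triangle.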

+ 15 - 0
ggml/src/ggml-cpu/ggml-cpu.c

@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_sum_rows(params, tensor);
             } break;
+        case GGML_OP_CUMSUM:
+            {
+                ggml_compute_forward_cumsum(params, tensor);
+            } break;
         case GGML_OP_MEAN:
             {
                 ggml_compute_forward_mean(params, tensor);
@@ -1943,6 +1947,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
+        case GGML_OP_TRI:
+            {
+                ggml_compute_forward_tri(params, tensor);
+            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor);
@@ -1998,6 +2006,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net_f32(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2153,6 +2165,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             {
                 n_tasks = 1;
             } break;
@@ -2297,6 +2311,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_DELTA_NET:
             {
                 n_tasks = 1;
             } break;

+ 684 - 1
ggml/src/ggml-cpu/ops.cpp

@@ -9,6 +9,7 @@
 
 #include <float.h>
 #include <algorithm>
+#include <cmath>
 
 // ggml_compute_forward_dup
 
@@ -2171,6 +2172,57 @@ void ggml_compute_forward_sum(
     }
 }
 
+// ggml_compute_forward_cumsum
+
+static void ggml_compute_forward_cumsum_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                ggml_vec_cumsum_f32(ne00, dst_row, src_row);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_cumsum(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cumsum_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_sum_rows
 
 static void ggml_compute_forward_sum_rows_f32(
@@ -2917,6 +2969,49 @@ static void ggml_compute_forward_gelu(
     }
 }
 
+// ggml_compute_forward_tri
+
+static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0];
+    float c = *((float *) &(dst->op_params[1]));
+    bool keep_org_val = isnan(c);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->ne[0] == src0->ne[1]);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        float        * dst_ptr  = (float  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+        float        * src  = (float  *)           ((char *)       src0->data  + i03*nb03  + i02*nb02  + i01*nb01 );
+        ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype);
+    }
+
+}
+
+void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_tri_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_gelu_erf
 
 static void ggml_compute_forward_gelu_erf_f32(
@@ -10362,8 +10457,596 @@ void ggml_compute_forward_gla(
     }
 }
 
-// ggml_compute_forward_rwkv_wkv7
+// Helper function to compute cumulative sum
+static void ggml_cumsum_f32(const float * x, float * dst, const int64_t n) {
+    float cumsum = 0.0f;
+    for (int64_t i = 0; i < n; i++) {
+        cumsum += x[i];
+        dst[i] = cumsum;
+    }
+}
+
+// Helper function for matrix multiplication
+static void ggml_matmul_f32(const float * a, const float * b, float * dst,
+                              const int64_t m, const int64_t n, const int64_t k) {
+    for (int64_t i = 0; i < m; i++) {
+        for (int64_t j = 0; j < n; j++) {
+            float sum = 0.0f;
+            for (int64_t l = 0; l < k; l++) {
+                sum += a[i * k + l] * b[l * n + j];
+            }
+            dst[i * n + j] = sum;
+        }
+    }
+}
+
+// Helper function to create upper triangular mask
+static void ggml_create_upper_triangular_mask(bool * mask, const int64_t size) {
+    for (int64_t i = 0; i < size; i++) {
+        for (int64_t j = 0; j < size; j++) {
+            mask[i * size + j] = (j >= i); // upper triangular with diagonal
+        }
+    }
+}
+
+// Helper function to compute chunk decay mask
+static void ggml_compute_chunk_decay_mask_f32(const float * g_cumsum, float * decay_mask,
+                                                 const int64_t chunk_size) {
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t j = 0; j < chunk_size; j++) {
+            if (i >= j) { // Only compute for lower triangular (including diagonal)
+                float g_diff = g_cumsum[i] - g_cumsum[j];
+                decay_mask[i * chunk_size + j] = expf(-g_diff);
+            } else {
+                decay_mask[i * chunk_size + j] = 0.0f; // Causal mask
+            }
+        }
+    }
+}
+
+// Helper function to compute k_beta @ key.T
+static void ggml_compute_k_beta_key_t_f32(const float * k_beta, const float * key,
+                                             float * k_beta_key_t,
+                                             const int64_t chunk_size, const int64_t k_head_dim) {
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t j = 0; j < chunk_size; j++) {
+            float sum = 0.0f;
+            for (int64_t d = 0; d < k_head_dim; d++) {
+                int64_t k_beta_idx = i * k_head_dim + d;
+                int64_t key_idx = j * k_head_dim + d;
+                sum += k_beta[k_beta_idx] * key[key_idx];
+            }
+            k_beta_key_t[i * chunk_size + j] = sum;
+        }
+    }
+}
+
+// Helper function to apply triangular updates
+static void ggml_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
+    for (int64_t i = 1; i < chunk_size; i++) {
+        for (int64_t j = 0; j < i; j++) {
+            float sum = 0.0f;
+            for (int64_t k = 0; k < i; k++) {
+                sum += attn[i * chunk_size + k] * attn[k * chunk_size + j];
+            }
+            attn[i * chunk_size + j] += sum;
+        }
+    }
+}
+
+// Helper function to add identity matrix
+static void ggml_add_identity_matrix_f32(float * matrix, const int64_t size) {
+    for (int64_t i = 0; i < size; i++) {
+        matrix[i * size + i] += 1.0f;
+    }
+}
+
+// Helper function to compute value = attn @ v_beta
+static void ggml_compute_value_f32(const float * attn, const float * v_beta,
+                                      float * value,
+                                      const int64_t chunk_size, const int64_t v_head_dim) {
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t d = 0; d < v_head_dim; d++) {
+            float sum = 0.0f;
+            for (int64_t j = 0; j < chunk_size; j++) {
+                int64_t v_beta_idx = j * v_head_dim + d;
+                sum += attn[i * chunk_size + j] * v_beta[v_beta_idx];
+            }
+            value[i * v_head_dim + d] = sum;
+        }
+    }
+}
+
+// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+static void ggml_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const double * g_exp,
+                                           float * k_cumdecay,
+                                           const int64_t chunk_size, const int64_t k_head_dim) {
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t d = 0; d < k_head_dim; d++) {
+            float sum = 0.0f;
+            for (int64_t j = 0; j < chunk_size; j++) {
+                int64_t k_beta_idx = j * k_head_dim + d;
+                sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * g_exp[j];
+            }
+            k_cumdecay[i * k_head_dim + d] = sum;
+        }
+    }
+}
+// Helper functions for delta net computation
+
+// Matrix multiplication helper for delta net
+static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, const int64_t cols_a, const int64_t cols_b,
+                           const float * b, float * result) {
+    for (int64_t i = 0; i < rows_a; i++) {
+        for (int64_t j = 0; j < cols_b; j++) {
+            float sum = 0.0f;
+            for (int64_t k = 0; k < cols_a; k++) {
+                int64_t a_idx = i * cols_a + k;
+                int64_t b_idx = k * cols_b + j;
+                sum += a[a_idx] * b[b_idx];
+            }
+            result[i * cols_b + j] = sum;
+        }
+    }
+}
+
+void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // q (already normalized and scaled)
+    const struct ggml_tensor * src1 = dst->src[1];  // k (already normalized)
+    const struct ggml_tensor * src2 = dst->src[2];  // v
+    const struct ggml_tensor * src3 = dst->src[3];  // g (cumsum)
+    const struct ggml_tensor * src4 = dst->src[4];  // state
+    const struct ggml_tensor * src5 = dst->src[5];  // decay_mask
+    const struct ggml_tensor * src6 = dst->src[6];  // v_beta
+    const struct ggml_tensor * src7 = dst->src[7];  // k_beta
+    const struct ggml_tensor * src8 = dst->src[8];  // attn
+
+    const int64_t H_v               = (int64_t) dst->op_params[0];
+    const int64_t S_k               = (int64_t) dst->op_params[1];
+    const int64_t S_v               = (int64_t) dst->op_params[2];
+    const int64_t original_n_tokens = (int64_t) dst->op_params[3];  // Get original sequence length
+    const int64_t H_k               = H_v;
+    const int64_t n_tokens          = original_n_tokens;            // Use the original sequence length
+    const int64_t n_seqs            = src0->ne[3];                  // q tensor has n_seqs in dim 3
+
+    // Add assertions to verify tensor dimensions
+    GGML_ASSERT(src0->ne[3] == n_seqs);  // q tensor
+    GGML_ASSERT(src1->ne[3] == n_seqs);  // k tensor
+    GGML_ASSERT(src2->ne[3] == n_seqs);  // v tensor
+    GGML_ASSERT(src3->ne[3] == n_seqs);  // g tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs);  // state tensor
+
+    float * dst_data  = (float *) dst->data;
+    // Following GLA pattern: output is first part, state is second part
+    float * output    = dst_data;  // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
+    float * new_state = dst_data + (S_v * H_v * n_tokens);  // [S_v * H_v, S_v * n_seqs, 1, 1]
+
+    const int ith = params->ith;
+    // const int nth = params->nth;  // nth is unused
+
+    // For chunked implementation, we process all sequences in thread 0 for simplicity
+    // This can be optimized later to parallelize across sequences
+    if (ith != 0) {
+        return;
+    }
+
+    // Clear output and new state section
+    memset(output, 0, ((S_v * H_v * n_tokens) + (S_v * H_v * S_v * n_seqs)) * sizeof(float));
+
+    // Get tensor data pointers
+    float * state_data = (float *) src4->data;
+    float * decay_mask = (float *) src5->data;
+
+    // Allocate temporary buffers for computation
+    const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
+    // The first dimension is the chunk_size, second is head_dim, third is num_heads, fourth is n_seqs
+    // Note: In reference Python implementation, tensors are padded to multiple of chunk_size
+    // but the output only contains the real sequence length, not the padded length
+
+    // Calculate the actual padded sequence length for internal processing
+    const int64_t pad_size              = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t total_sequence_length = n_tokens + pad_size;
+    const int64_t n_chunks              = (total_sequence_length + chunk_size - 1) / chunk_size;  // Ceiling division
+
+    // Temporary buffers for each chunk
+    std::vector<float>  attn(chunk_size * chunk_size, 0.0f);
+    std::vector<float>  value(chunk_size * S_v, 0.0f);
+    std::vector<float>  k_cumdecay(chunk_size * S_k, 0.0f);
+    std::vector<double> g_exp(chunk_size, 0.0f);
+    std::vector<float>  g_cumsum(chunk_size, 0.0f);
+    std::vector<float>  last_state(S_v * S_v * H_v, 0.0f);
+
+    // Initialize last_state with input state data
+    // State format in GGML: [S_v, S_v * H_v, 1, 1], i.e. the second dimension is S_v * num_heads
+    // For delta_net, S_k == S_v (both k and v have the same head dimension)
+    for (int64_t h = 0; h < H_v; h++) {
+        for (int64_t d1 = 0; d1 < S_v; d1++) {
+            for (int64_t d2 = 0; d2 < S_v; d2++) {
+                // GGML state index: [d1, d2 + h*S_v, 0, 0] in flattened form
+                int64_t ggml_state_idx         = d1 * (S_v * H_v) + (d2 + h * S_v);
+                // Our computed state index: [d1, d2 + h*S_v]
+                int64_t computed_state_idx     = d1 * (S_v * H_v) + (d2 + h * S_v);
+                last_state[computed_state_idx] = state_data[ggml_state_idx];
+            }
+        }
+    }
+
+    // Maintain running cumulative sum across all chunks
+    std::vector<float> running_cumsum(n_tokens, 0.0f);
+
+    // Process each chunk
+    for (int64_t chunk_idx = 0; chunk_idx < n_chunks; chunk_idx++) {
+        // Process each head and sequence
+        for (int64_t h = 0; h < H_k; h++) {
+            for (int64_t seq = 0; seq < n_seqs; seq++) {
+                // Extract chunk data for this head and sequence
+                std::vector<float> q_chunk(chunk_size * S_k);
+                std::vector<float> k_chunk(chunk_size * S_k);
+                std::vector<float> v_chunk(chunk_size * S_v);
+                std::vector<float> v_beta_chunk(chunk_size * S_v);
+                std::vector<float> k_beta_chunk(chunk_size * S_k);
+                std::vector<float> g_chunk(chunk_size);
+
+                // Initialize chunks with zeros for padding
+                std::fill(q_chunk.begin(), q_chunk.end(), 0.0f);
+                std::fill(k_chunk.begin(), k_chunk.end(), 0.0f);
+                std::fill(v_chunk.begin(), v_chunk.end(), 0.0f);
+                std::fill(v_beta_chunk.begin(), v_beta_chunk.end(), 0.0f);
+                std::fill(k_beta_chunk.begin(), k_beta_chunk.end(), 0.0f);
+                std::fill(g_chunk.begin(), g_chunk.end(), 0.0f);
+
+                // Determine actual tokens in this chunk
+                int64_t tokens_in_chunk = std::min(chunk_size, n_tokens - chunk_idx * chunk_size);
+
+                // Copy data for this chunk
+                for (int64_t t = 0; t < tokens_in_chunk; t++) {
+                    int64_t actual_pos = chunk_idx * chunk_size + t;  // Position in the original sequence
+
+                    // Only copy if this position is within the original sequence length
+                    if (actual_pos < n_tokens) {
+                        // Calculate indices in GGML format [chunk_size, head_dim, num_heads, n_seqs]
+                        for (int64_t d = 0; d < S_k; d++) {
+                            q_chunk[t * S_k + d]      = ggml_get_f32_nd(src0, actual_pos, d, h, seq);
+                            k_chunk[t * S_k + d]      = ggml_get_f32_nd(src1, actual_pos, d, h, seq);
+                            k_beta_chunk[t * S_k + d] = ggml_get_f32_nd(src7, actual_pos, d, h, seq);
+                        }
+
+                        for (int64_t d = 0; d < S_v; d++) {
+                            v_chunk[t * S_v + d]      = ggml_get_f32_nd(src2, actual_pos, d, h, seq);
+                            v_beta_chunk[t * S_v + d] = ggml_get_f32_nd(src6, actual_pos, d, h, seq);
+                        }
+
+                        // actual_pos < n_tokens is already guaranteed by the enclosing check,
+                        // so the cumsummed g value can be read directly
+                        // Use the safe GGML function to access tensor values
+                        g_chunk[t] = ggml_get_f32_nd(src3, actual_pos, 0, h, seq);
+                    } else {
+                        // For padded positions beyond original sequence, set to 0
+                        for (int64_t d = 0; d < S_k; d++) {
+                            q_chunk[t * S_k + d]      = 0.0f;
+                            k_chunk[t * S_k + d]      = 0.0f;
+                            k_beta_chunk[t * S_k + d] = 0.0f;
+                        }
+                        for (int64_t d = 0; d < S_v; d++) {
+                            v_chunk[t * S_v + d]      = 0.0f;
+                            v_beta_chunk[t * S_v + d] = 0.0f;
+                        }
+                        g_chunk[t] = 0.0f;
+                    }
+                }
+
+                // In Python, cumsum is applied to each chunk separately after reshaping
+                // So we need to compute cumsum within this chunk only
+
+                // g_chunk already contains the cumsum values from src3 (g_cumsum), so use them directly
+                for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only for actual tokens, not full chunk size
+                    g_cumsum[i] = g_chunk[i];
+                }
+
+                // For padded positions, set cumsum values to 0
+                for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
+                    g_cumsum[i] = 0.0f;
+                }
+
+                // Compute g_exp from cumulative sums (like Python: g.cumsum().exp())
+                for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only for actual tokens, not full chunk size
+                    // Use double precision for the exponential to reduce overflow/underflow;
+                    // no explicit clamping is applied here
+                    double g_val        = (double) g_cumsum[i];
+                    double g_exp_double = exp(g_val);
+                    g_exp[i]            = g_exp_double;
+                }
+
+                // For padded positions, set exp values to 0
+                for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
+                    g_exp[i] = 0.0f;
+                }
+                // Step 1: Compute k_beta @ key.T (this corresponds to the Python: k_beta @ key.transpose(-1, -2))
+                // Only compute for actual tokens in chunk
+                ggml_compute_k_beta_key_t_f32(k_beta_chunk.data(), k_chunk.data(), attn.data(), tokens_in_chunk,
+                                              S_k);  // Use actual tokens, not full chunk size
+
+                // Apply precomputed decay mask from src5 and negate the result (like Python: -(...))
+                // The decay mask is computed in ggml_delta_net in ggml.c and passed as src5
+                // Apply the precomputed decay mask from src5 (decay_mask tensor)
+                // The decay_mask tensor now contains exp(g_cumsum[j] - g_cumsum[i]) values
+                // where g_cumsum[j] - g_cumsum[i] is computed in the main function
+                for (int64_t i = 0; i < tokens_in_chunk; i++) {      // Only for actual tokens
+                    for (int64_t j = 0; j < tokens_in_chunk; j++) {  // Only for actual tokens
+                        // Get decay mask value from precomputed tensor
+                        // src5 decay_mask has shape [chunk_size, chunk_size, H_k, n_seqs] in GGML format
+                        // Format: [i_pos, j_pos, head, seq] - represents exp(g_cumsum[j] - g_cumsum[i])
+                        float decay_val = ggml_get_f32_nd(
+                            src5, i, j, h, seq);  // [i, j, h, seq] to get exp(g_cumsum[j] - g_cumsum[i]) for head h
+                        if (j <= i) {             // Only apply to lower triangular part (i >= j)
+                            // The decay_val already contains exp(g_cumsum[j] - g_cumsum[i]), no need for additional exponential
+                            // Apply the decay mask and negate (like Python: -((k_beta @ key.T) * decay_mask))
+                            // Use tokens_in_chunk as the row stride, matching how attn was written in Step 1
+                            attn[i * tokens_in_chunk + j] = -attn[i * tokens_in_chunk + j] * decay_val;
+                        } else {
+                            attn[i * tokens_in_chunk + j] =
+                                0.0f;  // Zero out upper triangular part (like Python: masked_fill(mask, 0))
+                        }
+                    }
+                }
+
+                // Step 2: Apply triangular updates (equivalent to Python's complex triangular update)
+                // Python: for i in range(1, chunk_size):
+                //           row = attn[..., i, :i].clone()  // row = attn[i, 0:i]
+                //           sub = attn[..., :i, :i].clone() // sub = attn[0:i, 0:i]
+                //           attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+                // This means: new_attn[i, j] = old_attn[i, j] + sum_k(old_attn[i, k] * old_attn[k, j]) for k < i
+                for (int64_t i = 1; i < tokens_in_chunk; i++) {
+                    // Store the original row values to avoid using updated values in computation
+                    std::vector<float> original_row(i);
+                    for (int64_t j = 0; j < i; j++) {
+                        original_row[j] = attn[i * tokens_in_chunk + j];  // Use tokens_in_chunk for indexing
+                    }
+
+                    for (int64_t j = 0; j < i; j++) {
+                        float sum = 0.0f;
+                        for (int64_t k = 0; k < i; k++) {
+                            // This implements: sum over k of (original_row[k] * sub[k, j])
+                            // Where sub[k, j] is attn[k, j] (the original value before updates)
+                            sum += original_row[k] * attn[k * tokens_in_chunk + j];  // Use tokens_in_chunk for indexing
+                        }
+                        // The new value is: original_value + matrix_mult_result
+                        attn[i * tokens_in_chunk + j] = original_row[j] + sum;
+                    }
+                }
+
+                // Step 3: Add identity matrix (equivalent to Python's: attn = attn + torch.eye(...))
+                for (int64_t i = 0; i < tokens_in_chunk; i++) {
+                    attn[i * tokens_in_chunk + i] += 1.0f;
+                }
+
+                // Step 4: Compute value = attn @ v_beta
+                ggml_compute_value_f32(attn.data(), v_beta_chunk.data(), value.data(), tokens_in_chunk,
+                                       S_v);  // Use actual tokens, not full chunk size
+
+                // Step 5: Compute k_cumdecay = attn @ (k_beta * g_exp)
+                ggml_compute_k_cumdecay_f32(attn.data(), k_beta_chunk.data(), g_exp.data(), k_cumdecay.data(),
+                                            tokens_in_chunk, S_k);  // Use actual tokens, not full chunk size
 
+                // Step 6: Compute core attention output for this chunk
+                // First, compute v_new for all tokens in the chunk
+                std::vector<float> v_new_chunk(tokens_in_chunk * S_v);
+                for (int64_t i = 0; i < tokens_in_chunk; i++) {
+                    // v_prime = k_cumdecay @ last_state
+                    // k_cumdecay[i] is [S_k], last_state for head h is [S_k, S_v]
+                    std::vector<float> v_prime(S_v, 0.0f);
+                    for (int64_t d1 = 0; d1 < S_v; d1++) {
+                        for (int64_t d2 = 0; d2 < S_k; d2++) {
+                            // State index: [d2, d1 + h*S_v] in GGML format
+                            int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
+                            v_prime[d1] += k_cumdecay[i * S_k + d2] * last_state[state_idx];
+                        }
+                    }
+
+                    // v_new = v_i - v_prime
+                    for (int64_t d = 0; d < S_v; d++) {
+                        v_new_chunk[i * S_v + d] = value[i * S_v + d] - v_prime[d];
+                    }
+                }
+
+                // Now process each token in the chunk to compute output
+                for (int64_t i = 0; i < tokens_in_chunk; i++) {
+                    // q_i @ k_i.T * decay_mask
+                    std::vector<float> q_k_attn(chunk_size);
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        float sum = 0.0f;
+                        for (int64_t d = 0; d < S_k; d++) {
+                            sum += q_chunk[i * S_k + d] * k_chunk[j * S_k + d];
+                        }
+                        // Apply decay mask - use the precomputed decay mask from src5 tensor
+                        if (j <= i) {                 // Only apply to lower triangular part (i >= j)
+                            float decay_val = ggml_get_f32_nd(
+                                src5, i, j, h, seq);  // [i, j, h, seq] to get exp(g_cumsum[j] - g_cumsum[i]) for head h
+                            q_k_attn[j] = sum * decay_val;
+                        } else {
+                            q_k_attn[j] = 0.0f;  // Zero out upper triangular part (i < j)
+                        }
+                    }
+
+                    // attn_inter = q_i * g_exp @ last_state
+                    // q_chunk[i] is [S_k], g_exp[i] is scalar, last_state for head h is [S_k, S_v]
+                    std::vector<float> attn_inter(S_v, 0.0f);
+                    for (int64_t d1 = 0; d1 < S_v; d1++) {
+                        for (int64_t d2 = 0; d2 < S_k; d2++) {
+                            // State index: [d2, d1 + h*S_v] in GGML format
+                            int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
+                            // Use double precision for the computation and then cast to float
+                            double  temp_result =
+                                (double) q_chunk[i * S_k + d2] * g_exp[i] * (double) last_state[state_idx];
+                            attn_inter[d1] += (float) temp_result;
+                        }
+                    }
+
+                    // core_attn_out = attn_inter + attn @ v_new
+                    // We need to use the attention matrix computed for this position (i)
+                    // The attn matrix was computed earlier in the chunk processing
+                    // attn @ v_new where attn is [chunk_size, chunk_size] and v_new is [chunk_size, S_v]
+                    // For token i, we want sum_j(attn[i, j] * v_new[j, :])
+                    std::vector<float> attn_v_new(S_v, 0.0f);
+                    for (int64_t d = 0; d < S_v; d++) {
+                        for (int64_t j = 0; j < tokens_in_chunk; j++) {  // Only process actual tokens
+                            // Use the attention matrix that was computed for position i
+                            // attn[i * tokens_in_chunk + j] is the attention from position i to j (tokens_in_chunk row stride)
+                            // v_new_chunk[j * S_v + d] is the v_new value for token j, dimension d
+                            attn_v_new[d] += attn[i * tokens_in_chunk + j] * v_new_chunk[j * S_v + d];
+                        }
+                    }
+
+                    // Store output - only store for the original sequence length (not the padded part)
+                    int64_t global_pos =
+                        chunk_idx * chunk_size + i;  // Convert local chunk position to global sequence position
+                    if (global_pos < n_tokens) {     // Make sure we don't exceed the original sequence length
+                        for (int64_t d = 0; d < S_v; d++) {
+                            // Output tensor is [S_v * H_v * n_tokens] for single sequence (n_seqs=1)
+                            // Indexing: [dim_idx + head_idx*S_v + pos_idx*S_v*H_v]
+                            int64_t ggml_idx = d + h * S_v + global_pos * S_v * H_v;
+                            output[ggml_idx] = attn_inter[d] + attn_v_new[d];
+                        }
+                    }
+                }
+
+                // Step 7: Update last_recurrent_state
+                std::vector<float> new_state_vec(S_v * S_v * H_v);
+
+                // Update running cumulative sum with current chunk's values
+                float prev_cumsum = 0.0f;  // Cumulative sum from all previous chunks
+                if (chunk_idx > 0) {
+                    // Get the cumulative sum of the last token from the previous chunk
+                    int64_t prev_chunk_last_token = std::min(chunk_size, n_tokens - (chunk_idx - 1) * chunk_size) - 1;
+                    if (prev_chunk_last_token >= 0) {
+                        prev_cumsum = running_cumsum[(chunk_idx - 1) * chunk_size + prev_chunk_last_token];
+                    }
+                }
+
+                // Update running_cumsum for tokens in this chunk
+                for (int64_t t = 0; t < tokens_in_chunk; t++) {
+                    int64_t global_pos = chunk_idx * chunk_size + t;
+                    if (global_pos < n_tokens) {
+                        running_cumsum[global_pos] = prev_cumsum + g_cumsum[t];
+                    }
+                }
+
+                // Find the last token position in the current chunk (not the entire sequence)
+                int64_t last_pos_in_chunk =
+                    std::min((chunk_idx + 1) * chunk_size, n_tokens) - 1;  // Last actual token in this chunk
+                if (last_pos_in_chunk >= chunk_idx * chunk_size && last_pos_in_chunk < n_tokens) {
+                    float g_last =
+                        running_cumsum[last_pos_in_chunk];  // Use the last token's cumulative sum in this chunk
+                    // Use double precision for exponential to avoid overflow/underflow
+                    double g_last_exp_double = exp((double) g_last);
+                    float  g_last_exp        = (float) g_last_exp_double;
+
+                    // last_state * g_exp[last]
+                    for (int64_t i = 0; i < S_k; i++) {
+                        for (int64_t j = 0; j < S_v; j++) {
+                            // State index: [i, j + h*S_v] in GGML format
+                            int64_t state_idx                              = i * (S_v * H_v) + (j + h * S_v);
+                            new_state_vec[i * (S_v * H_v) + (j + h * S_v)] = last_state[state_idx] * g_last_exp;
+                        }
+                    }
+
+                    // Add (k_i * (g_last - g_i).exp()).T @ v_new
+                    // This should be: (k_chunk * g_diff_exp).T @ v_new_chunk
+                    // where k_chunk is [chunk_size, S_k], v_new_chunk is [chunk_size, S_v]
+                    // result is [S_k, S_v]
+
+                    // First compute v_new for all positions in the chunk
+                    std::vector<float> v_new_chunk(chunk_size * S_v);
+                    for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only process actual tokens, not full chunk size
+                        for (int64_t d1 = 0; d1 < S_v; d1++) {
+                            // Recompute v_prime for this position
+                            float v_prime = 0.0f;
+                            for (int64_t d2 = 0; d2 < S_k; d2++) {
+                                // State index: [d2, d1 + h*S_v] in GGML format
+                                int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
+                                float   k_val     = k_cumdecay[i * S_k + d2];
+                                float   s_val     = last_state[state_idx];
+                                v_prime += k_val * s_val;
+                            }
+                            v_new_chunk[i * S_v + d1] = value[i * S_v + d1] - v_prime;
+                        }
+                    }
+
+                    // Now compute (k_chunk * g_diff_exp).T @ v_new_chunk
+                    // This is a matrix multiplication: [S_k, chunk_size] @ [chunk_size, S_v] = [S_k, S_v]
+                    // Only process the original sequence length, not the padded chunk size
+                    // In the Python reference, this is: (k_i * g_diff_exp).transpose(-1, -2) @ v_new
+                    // where g_diff_exp = torch.exp(g_last - g) and g_last = g[-1] (last token in chunk)
+                    for (int64_t d1 = 0; d1 < S_k; d1++) {
+                        for (int64_t d2 = 0; d2 < S_v; d2++) {
+                            float sum = 0.0f;
+                            for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only process actual tokens
+                                // Get g values for the current chunk from the cumsum tensor (src3)
+                                // For state update: g_last (last token in chunk) - g_current (current token)
+                                // g tensor has shape [GGML_DELTA_NET_CHUNK, 1, H_v, n_seqs] in GGML format after cumsum and reshaping
+
+                                // Access g_cumsum for current position in chunk - need to access the original g tensor before cumsum
+                                // The g_cumsum tensor is src3, but we need the original g values for the diff computation
+                                // Actually, we need to access g values that were cumsummed to compute the diff
+
+                                // Get the original g_cumsum values for current and last token in the chunk
+                                // g_cumsum values are stored in src3, which was reshaped from [chunk_size, 1, H_v, n_seqs] to [chunk_size, 1, H_v, n_seqs]
+                                float g_current = g_cumsum[i];      // Use the g_cumsum computed earlier in this chunk
+                                float g_last =
+                                    g_cumsum[tokens_in_chunk - 1];  // Use the last token's cumsum in this chunk
+
+                                float g_diff = g_last - g_current;
+                                float g_diff_exp;
+                                // Use double precision for exponential to avoid overflow/underflow
+                                // For numerical stability, if g_diff is very negative, exp(g_diff) will be very small
+                                if (g_diff < -50.0f) {
+                                    g_diff_exp = 0.0f;  // Set to zero to avoid underflow
+                                } else {
+                                    double g_diff_exp_double = exp((double) g_diff);
+                                    g_diff_exp               = (float) g_diff_exp_double;
+                                }
+                                sum += k_chunk[i * S_k + d1] * g_diff_exp * v_new_chunk[i * S_v + d2];
+                            }
+                            // State index: [d1, d2 + h*S_v] in GGML format
+                            int64_t state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
+                            new_state_vec[state_idx] += sum;
+                        }
+                    }
+
+                    // Update last_state
+                    for (int64_t i = 0; i < S_k; i++) {
+                        for (int64_t j = 0; j < S_v; j++) {
+                            // State index: [i, j + h*S_v] in GGML format
+                            int64_t state_idx     = i * (S_v * H_v) + (j + h * S_v);
+                            last_state[state_idx] = new_state_vec[state_idx];
+                        }
+                    }
+                }
+            }
+        }
+    }
+    // Copy the final state to the output tensor in the correct GGML layout
+    // GGML expects state layout: [d1, d2 + h*head_dim]
+    for (int64_t h = 0; h < H_v; h++) {
+        for (int64_t d1 = 0; d1 < S_v; d1++) {
+            for (int64_t d2 = 0; d2 < S_v; d2++) {
+                // GGML state index: [d1, d2 + h*head_dim]
+                int64_t ggml_state_idx     = d1 * (S_v * H_v) + (d2 + h * S_v);
+                // Our computed state index: [d1, d2 + h*S_v]
+                int64_t computed_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
+                float   val                = last_state[computed_state_idx];
+                new_state[ggml_state_idx]  = val;
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_rwkv_wkv7
 static void ggml_compute_forward_rwkv_wkv7_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
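For orientation, the chunked kernel above is meant to compute, block by block, the same result as a simple per-token recurrence of the gated delta rule. A single-head, non-chunked scalar sketch of that recurrence (illustrative only, written here for comparison and not part of the commit; the state S has the [S_k, S_v] layout used above, and beta is passed explicitly instead of being pre-folded into k_beta/v_beta) is:

// Illustrative only: per-token gated delta rule for a single head.
// The chunked kernel processes GGML_DELTA_NET_CHUNK tokens at a time but
// should match this reference recurrence up to numerical error.
#include <math.h>

void delta_net_ref_step(
        float * S,            // [S_k * S_v] running state, updated in place
        float * o,            // [S_v]       output for this token
        const float * q,      // [S_k]       query (already scaled)
        const float * k,      // [S_k]       key   (unit-norm if use_qk_l2norm)
        const float * v,      // [S_v]       value
        float g,              // log decay for this token
        float beta,           // update gate for this token
        int S_k, int S_v) {
    const float alpha = expf(g);

    // 1. decay the state: S <- alpha * S
    for (int i = 0; i < S_k * S_v; i++) {
        S[i] *= alpha;
    }

    // 2. delta update: S <- S + beta * k (v - S^T k)^T
    for (int j = 0; j < S_v; j++) {
        float pred = 0.0f;                       // (S^T k)[j]
        for (int i = 0; i < S_k; i++) {
            pred += k[i] * S[i*S_v + j];
        }
        const float err = beta * (v[j] - pred);  // gated prediction error
        for (int i = 0; i < S_k; i++) {
            S[i*S_v + j] += k[i] * err;
        }
    }

    // 3. readout: o = S^T q
    for (int j = 0; j < S_v; j++) {
        float sum = 0.0f;
        for (int i = 0; i < S_k; i++) {
            sum += q[i] * S[i*S_v + j];
        }
        o[j] = sum;
    }
}

Chunking this recurrence is what introduces the decay mask, the triangular update plus identity step, and the k_cumdecay/v_new intermediates in the kernel above.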

+ 3 - 0
ggml/src/ggml-cpu/ops.h

@@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -85,6 +86,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
@@ -100,6 +102,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_delta_net_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);

+ 32 - 0
ggml/src/ggml-cpu/vec.h

@@ -1414,6 +1414,38 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #endif
 }
 
+// Applies a triangular mask to the input vector 'src' and writes the result to 'dst'.
+// Parameters:
+//   n            - number of elements
+//   r            - current row index
+//   dst          - output array
+//   src          - input array
+//   keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c'
+//   c            - constant value to use when not keeping original value
+//   type         - type of triangular mask (lower, upper, etc.)
+inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) {
+    for (int i = 0; i < n; ++i) {
+        bool cmp;
+        switch (type) {
+            case GGML_TRI_TYPE_LOWER: cmp = i < r; break;
+            case GGML_TRI_TYPE_LOWER_DIAG: cmp = i <= r; break;
+            case GGML_TRI_TYPE_UPPER: cmp = i > r; break;
+            case GGML_TRI_TYPE_UPPER_DIAG: cmp = i >= r; break;
+        }
+        dst[i] = cmp ? (keep_org_val ? src[i] : c) : 0.0f;
+    }
+}
+
+inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        if (i == 0) {
+            y[i] = x[i];
+        } else {
+            y[i] = y[i - 1] + x[i];
+        }
+    }
+}
+
 inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
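The semantics of the two new helpers are easiest to see on a concrete row. A small standalone check (illustrative; it assumes the internal ggml/src/ggml-cpu/vec.h header and the ggml_tri_type enum from ggml.h can be included, which normally requires the in-tree include paths):

// Illustrative check of the two new helpers on one 4-element row.
// ggml_vec_cumsum_f32 is a running sum; ggml_vec_tri_f32 keeps (or fills)
// the part of row r selected by the triangle type and zeroes the rest.
#include <stdio.h>
#include "vec.h"   // internal CPU header; pulls in ggml.h for ggml_tri_type

int main(void) {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float cs[4], tri[4];

    ggml_vec_cumsum_f32(4, cs, x);
    // cs == {1, 3, 6, 10}

    // row r = 1 of a 4x4 matrix, lower triangle including the diagonal:
    // columns 0 and 1 keep their original values, columns 2 and 3 become 0
    ggml_vec_tri_f32(4, 1, tri, x, /*keep_org_val=*/true, /*c=*/0.0f, GGML_TRI_TYPE_LOWER_DIAG);
    // tri == {1, 2, 0, 0}

    printf("cs[3] = %f, tri[1] = %f, tri[2] = %f\n", cs[3], tri[1], tri[2]);
    return 0;
}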

+ 168 - 2
ggml/src/ggml.c

@@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "COS",
     "SUM",
     "SUM_ROWS",
+    "CUMSUM",
     "MEAN",
     "ARGMAX",
     "COUNT_EQUAL",
@@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
+    "TRI",
 
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
@@ -1002,6 +1004,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
+    "DELTA_NET",
 
     "UNARY",
 
@@ -1019,7 +1022,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1039,6 +1042,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cos(x)",
     "Σx",
     "Σx_k",
+    "cumsum(x)",
     "Σx/n",
     "argmax(x)",
     "count_equal(x)",
@@ -1094,6 +1098,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
+    "tri(x)",
 
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
@@ -1106,6 +1111,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
+    "delta_net(q, k, v, g, beta, state)",
 
     "unary(x)",
 
@@ -1123,7 +1129,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2337,6 +2343,20 @@ struct ggml_tensor * ggml_sum_rows(
     return result;
 }
 
+// ggml_cumsum
+
+struct ggml_tensor * ggml_cumsum(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne);
+
+    result->op     = GGML_OP_CUMSUM;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_mean
 
 struct ggml_tensor * ggml_mean(
@@ -4935,6 +4955,33 @@ struct ggml_tensor * ggml_timestep_embedding(
     return result;
 }
 
+// ggml_tri
+
+struct ggml_tensor * ggml_tri(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    float constant,
+    enum ggml_tri_type tritype) {
+    
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_i32(result, 0, tritype);
+    ggml_set_op_params_f32(result, 1, constant);
+
+    result->op = GGML_OP_TRI;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_tri_keep(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    enum ggml_tri_type tritype) {
+
+    return ggml_tri(ctx, a, nan(""), tritype);
+}
+
 // ggml_argsort
 
 struct ggml_tensor * ggml_argsort(
@@ -5463,6 +5510,125 @@ struct ggml_tensor * ggml_rwkv_wkv7(
     return result;
 }
 
+// ggml_delta_net
+// prepare all the tensor data for the operation so we only
+// do the absolutely necessary steps in the op itself
+struct ggml_tensor * ggml_delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 scale,
+        float                 eps_norm
+    ) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k = q->ne[0];
+    const int64_t H_k = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == 1);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
+
+    if (use_qk_l2norm) {
+        q = ggml_l2_norm(ctx, q, eps_norm);
+        k = ggml_l2_norm(ctx, k, eps_norm);
+    }
+    q = ggml_scale(ctx, q, scale);
+
+    int64_t pad_size = (GGML_DELTA_NET_CHUNK - n_tokens % GGML_DELTA_NET_CHUNK) % GGML_DELTA_NET_CHUNK;
+    int64_t num_chunks = (n_tokens + pad_size) / GGML_DELTA_NET_CHUNK;
+
+    // First, permute to chunk format: [n_tokens, S_k, H_k, n_seqs]
+    q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3));
+    k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));
+    v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));
+    beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
+    g = ggml_cont(ctx, ggml_permute(ctx, g, 1, 2, 0, 3));
+
+    // Then, pad the sequence dimension (n_tokens) to chunk_size
+    q = ggml_pad(ctx, q, pad_size, 0, 0, 0); // [CS, S_k, H_k, n_seqs]
+    k = ggml_pad(ctx, k, pad_size, 0, 0, 0); // [CS, S_k, H_k, n_seqs]
+    v = ggml_pad(ctx, v, pad_size, 0, 0, 0); // [CS, S_v, H_v, n_seqs]
+    beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0); // [CS, 1, H_v, n_seqs]
+    g = ggml_pad(ctx, g, pad_size, 0, 0, 0); // [CS, 1, H_v, n_seqs]
+    
+    GGML_ASSERT(q->ne[0] % GGML_DELTA_NET_CHUNK == 0 && q->ne[1] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] % GGML_DELTA_NET_CHUNK == 0 && k->ne[1] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] % GGML_DELTA_NET_CHUNK == 0 && v->ne[1] == S_v && v->ne[2] == H_v && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == 1 && beta->ne[2] == H_v && beta->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[1] == 1 && g->ne[2] == H_v && g->ne[3] == n_seqs);
+
+    struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
+    struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
+    // struct ggml_tensor * mask = ggml_tri(ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK), 1.0f, GGML_TRI_TYPE_UPPER);
+    struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
+        
+    struct ggml_tensor * gcs_i = ggml_cont(ctx, g_cumsum);  // [padded_len, 1, H_v, n_seqs]
+    struct ggml_tensor * gcs_j = ggml_cont(ctx, ggml_permute(ctx, g_cumsum, 1, 0, 2, 3));  // [1, padded_len, H_v, n_seqs]
+
+    // Broadcast both tensors to [padded_len, padded_len, H_v, n_seqs], where padded_len = num_chunks * GGML_DELTA_NET_CHUNK
+    struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);
+    struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);
+    
+    struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast); 
+    
+    // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
+    decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER);   
+    // Apply exponential to get the decay mask values
+    decay_mask = ggml_exp(ctx, decay_mask);
+    // Apply lower triangular mask again to ensure only lower triangular values remain
+    decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER);
+
+    GGML_LOG_INFO("k_beta shape = [%ld, %ld, %ld, %ld], k shape = [%ld, %ld, %ld, %ld]\n", k_beta->ne[0], k_beta->ne[1], k_beta->ne[2], k_beta->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
+    struct ggml_tensor * attn = ggml_neg(ctx, ggml_tri_keep(ctx, ggml_mul(ctx, ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, k_beta)), ggml_cont(ctx, ggml_transpose(ctx, k))), decay_mask), GGML_TRI_TYPE_LOWER));
+    GGML_LOG_INFO("attn shape = [%ld, %ld, %ld, %ld]\n", attn->ne[0], attn->ne[1], attn->ne[2], attn->ne[3]);
+
+    // The result is returned as a 1D tensor because the output and state parts have different shapes:
+    // the output part is sized by the original n_tokens, the state part by S_v * S_v * H_v per sequence
+    const int64_t ne[1] = { (S_v * H_v * n_tokens) + (S_v * S_v * H_v * n_seqs) };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 1, ne);
+
+    ggml_set_op_params_i32(result, 0, H_v);
+    ggml_set_op_params_i32(result, 1, S_k);
+    ggml_set_op_params_i32(result, 2, S_v);
+    ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
+
+    result->op     = GGML_OP_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g_cumsum;
+    result->src[4] = state;
+    result->src[5] = decay_mask;
+    result->src[6] = v_beta;
+    result->src[7] = k_beta;
+    result->src[8] = attn;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
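The decay-mask graph built in ggml_delta_net above (pad, ggml_cumsum, broadcast subtraction, ggml_tri_keep, ggml_exp, ggml_tri_keep) reduces, per head and sequence, to the following scalar computation. This is only a mirror of what the graph does, with illustrative buffer names; note that GGML_TRI_TYPE_LOWER is strictly below the diagonal (see ggml_vec_tri_f32), so the resulting mask has a zero diagonal:

// Scalar mirror of the decay-mask graph for one head and one sequence.
// gcs is the cumulative sum of the (padded) per-token log-decay g.
// In ggml's layout, the row index is dim 1 and the column index is dim 0,
// so the element read by ggml_get_f32_nd(decay_mask, i, j, h, s) is mask[j][i].
// Names here (g, gcs, mask) are illustrative, not from the commit.
#include <math.h>
#include <string.h>

#define CHUNK 64  // GGML_DELTA_NET_CHUNK

void build_decay_mask(const float * g, int n_tokens, float * gcs, float * mask /* [len*len] */) {
    const int pad = (CHUNK - n_tokens % CHUNK) % CHUNK;
    const int len = n_tokens + pad;

    // ggml_cumsum over the padded token dimension (padding contributes zeros)
    float acc = 0.0f;
    for (int t = 0; t < len; t++) {
        acc    += (t < n_tokens) ? g[t] : 0.0f;
        gcs[t]  = acc;
    }

    // mask[r][c] = exp(gcs[r] - gcs[c]) strictly below the diagonal, 0 elsewhere.
    // The strict-lower tri_keep is applied both before and after ggml_exp,
    // which is why the diagonal and upper triangle end up as zeros.
    memset(mask, 0, (size_t) len * len * sizeof(float));
    for (int r = 0; r < len; r++) {
        for (int c = 0; c < r; c++) {
            mask[r*len + c] = expf(gcs[r] - gcs[c]);
        }
    }
}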

+ 54 - 192
src/models/llm_build_qwen3next.cpp

@@ -98,155 +98,6 @@ struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * in
     return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
 }
 
-// ggml_delta_net
-struct ggml_tensor * llm_build_qwen3next::ggml_delta_net(struct ggml_tensor * q,
-                                                         struct ggml_tensor * k,
-                                                         struct ggml_tensor * v,
-                                                         struct ggml_tensor * g,
-                                                         struct ggml_tensor * beta,
-                                                         struct ggml_tensor * state,
-                                                         bool                 use_qk_l2norm,
-                                                         float                scale,
-                                                         int                  il) {
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
-    cb(k, "k_delta_in", il);
-    cb(v, "v_delta_in", il);
-    cb(q, "q_delta_in", il);
-    cb(g, "g_delta_in", il);
-    cb(beta, "beta_delta_in", il);
-    cb(state, "state_delta_in", il);
-
-    const int64_t S_k      = k->ne[0];
-    const int64_t H_k      = k->ne[1];
-    const int64_t n_tokens = k->ne[2];
-    const int64_t n_seqs   = k->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seqs && state->ne[3] == n_tokens);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
-
-    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
-
-    // Beta sigmoid
-    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx0, beta);
-    cb(beta_sigmoid, "beta_sigmoid", il);
-
-    // Gate calculations are done elsewhere in llama-model.cpp
-
-    struct ggml_tensor * q_broadcast = q;
-    struct ggml_tensor * k_broadcast = k;
-
-    // if head keys and value keys are different, repeat to force tensors into matching shapes
-    if (H_k != H_v) {
-        GGML_ASSERT(H_v % H_k == 0);
-        int64_t repeat_factor = H_v / H_k;
-
-        q_broadcast = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
-        k_broadcast = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
-
-        q_broadcast = ggml_repeat_4d(ctx0, q_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
-        k_broadcast = ggml_repeat_4d(ctx0, k_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
-
-        q_broadcast = ggml_reshape_4d(ctx0, q_broadcast, S_k, H_v, n_seqs, n_tokens);
-        k_broadcast = ggml_reshape_4d(ctx0, k_broadcast, S_k, H_v, n_seqs, n_tokens);
-    }
-    struct ggml_tensor * v_reshape       = ggml_cont_4d(ctx0, v, S_v, H_v, n_seqs, n_tokens);
-    struct ggml_tensor * g_reshape       = ggml_cont_4d(ctx0, g, S_v, H_v, n_seqs, n_tokens);
-    struct ggml_tensor * beta_broadcast  = ggml_cont_4d(ctx0, beta_sigmoid, 1, H_v, n_seqs, n_tokens);
-    struct ggml_tensor * state_broadcast = ggml_cont(ctx0, state);
-
-    return ggml_delta_net_op(q_broadcast, k_broadcast, v_reshape, g_reshape, beta_broadcast, state_broadcast,
-                             use_qk_l2norm, scale, il);
-}
-
-struct ggml_tensor * llm_build_qwen3next::ggml_delta_net_op(struct ggml_tensor * q,
-                                                            struct ggml_tensor * k,
-                                                            struct ggml_tensor * v,
-                                                            struct ggml_tensor * g,
-                                                            struct ggml_tensor * beta,
-                                                            struct ggml_tensor * state,
-                                                            bool                 use_qk_l2norm,
-                                                            float                scale,
-                                                            int                  il) {
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_seq    = q->ne[2];
-    const int64_t n_tokens = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(H_k == H_v);  // we broadcasted the tensors in the main function to guarantee this
-
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_seq && k->ne[3] == n_tokens);
-    GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_seq && v->ne[3] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_seq && g->ne[3] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == H_v && beta->ne[2] == n_seq && beta->ne[3] == n_tokens);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seq && state->ne[3] == n_tokens);
-
-    struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, S_v, S_v * H_v, n_seq, n_tokens);
-
-    new_state = ggml_cpy(ctx0, state, new_state);
-    cb(new_state, "new_state", il);
-
-    if (use_qk_l2norm) {
-        q = ggml_l2_norm(ctx0, q, 1e-6f);
-        cb(q, "q_l2_norm", il);
-        k = ggml_l2_norm(ctx0, k, 1e-6f);
-        cb(q, "k_l2_norm", il);
-    }
-    q = ggml_scale(ctx0, q, scale);
-    cb(q, "q_scaled", il);
-
-    struct ggml_tensor * state_decay = ggml_mul(ctx0, state, g);
-    cb(state_decay, "state_decay", il);
-    struct ggml_tensor * kv_mem_presum = ggml_mul(ctx0, state_decay, k);
-
-    // Gotta do some squeezing here...
-    struct ggml_tensor * kv_mem_presum_squeeze = ggml_cont_4d(ctx0, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
-    struct ggml_tensor * kv_mem = ggml_permute(ctx0, ggml_sum_rows(ctx0, kv_mem_presum_squeeze), 3, 0, 1, 2);
-    cb(kv_mem, "kv_mem", il);
-    struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx0, kv_mem, S_v, H_v, n_seq, n_tokens);
-    struct ggml_tensor * delta          = ggml_mul(ctx0, ggml_sub(ctx0, v, kv_mem_reshape), beta);
-    cb(delta, "delta", il);
-    struct ggml_tensor * delta_kt = ggml_mul(ctx0, delta, k);
-    cb(delta_kt, "delta_kt", il);
-    struct ggml_tensor * state_plus_k_delta = ggml_add(ctx0, state_decay, delta_kt);
-    cb(state_plus_k_delta, "state_plus_k_delta", il);
-    struct ggml_tensor * state_q = ggml_mul(ctx0, state_plus_k_delta, q);
-    cb(state_q, "state_q", il);
-
-    // And here...
-    state_q                     = ggml_reshape_4d(ctx0, state_q, S_v, S_v, H_v, n_seq * n_tokens);
-    struct ggml_tensor * output = ggml_permute(ctx0, ggml_sum_rows(ctx0, state_q), 2, 0, 1, 3);
-    output                      = ggml_reshape_4d(ctx0, output, S_v, H_v, n_seq, n_tokens);
-    cb(output, "delta_net_output", il);
-
-    struct ggml_tensor * result = ggml_concat(ctx0, output, state_plus_k_delta, 1);
-    cb(result, "delta_net_result", il);
-    return result;
-}
-
 ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *             cur,
                                                                    ggml_tensor *             inp_pos,
                                                                    llm_graph_input_attn_kv * inp_attn,
@@ -356,18 +207,14 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(a, "a", il);
 
     // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
-    ggml_tensor * beta  = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
-    ggml_tensor * alpha = ggml_reshape_3d(ctx0, ggml_cont(ctx0, a), num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, n_seqs);
 
     GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
     ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
     cb(alpha_softplus, "a_softplus", il);
-    ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a);       // A_log.exp()
-    cb(A_log_exp, "a_logexp", il);
-    ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp);  // A_log.exp() * softplus
-    cb(gate_scaled, "gate_scaled", il);
-    ggml_tensor * gate = ggml_scale(ctx0, gate_scaled, -1.0f);              // - (A_log.exp() * softplus)
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
     // Get convolution states from cache
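The simplified gate above multiplies softplus(alpha + dt_bias) directly by model.layers[il].ssm_a, which only equals the removed -(A_log.exp() * softplus(...)) expression if ssm_a is stored as -exp(A_log) at conversion time, as llama.cpp does for some other SSM-style models; that convention is not visible in this diff, so treat it as an assumption. A scalar sketch of the intended math, with made-up values and a plain softplus standing in for the model's helper:

// Scalar sketch of the gate computed above (one head, one token), not the model code.
// Assumes ssm_a already stores -exp(A_log), which is what makes
// gate = softplus(alpha + dt_bias) * ssm_a equal to -(exp(A_log) * softplus(alpha + dt_bias)).
#include <math.h>
#include <stdio.h>

static float softplus(float x) {
    return log1pf(expf(x));    // log(1 + exp(x))
}

int main(void) {
    const float A_log   = 0.5f;    // hypothetical learned parameter
    const float dt_bias = 0.1f;    // hypothetical dt bias (ssm_dt)
    const float alpha   = -0.2f;   // hypothetical per-token input

    const float ssm_a = -expf(A_log);                       // assumed storage convention
    const float gate  = softplus(alpha + dt_bias) * ssm_a;  // == -(exp(A_log) * softplus(alpha + dt_bias))

    printf("gate (log-decay) = %f\n", gate);                // always negative, so exp(gate) lies in (0, 1)
    return 0;
}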
@@ -505,50 +352,65 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
     v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
 
-    // Beta tensor
-    beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
+    beta  = ggml_cont_4d(ctx0, b, 1, num_v_heads, n_tokens, n_seqs);
+    alpha = ggml_cont_4d(ctx0, a, 1, num_v_heads, n_tokens, n_seqs);
+
+    ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
+    gate = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
 
-    ggml_tensor * state           = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
-    ggml_tensor * state_broadcast = ggml_repeat_4d(ctx0, state, head_dim, head_dim * n_heads, n_seqs, n_tokens);
-    ggml_tensor * target_gate     = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
-    ggml_tensor * gate_broadcast  = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
-    gate                          = ggml_repeat(ctx0, gate_broadcast, target_gate);
+    // If the number of key heads differs from the number of value heads, repeat Q/K so the tensor shapes match
+    if (num_k_heads != num_v_heads) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        int64_t repeat_factor = num_v_heads / num_k_heads;
+
+        q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, n_tokens, num_k_heads, n_seqs);
+        k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, n_tokens, num_k_heads, n_seqs);
+
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, n_tokens * repeat_factor, num_k_heads, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, n_tokens * repeat_factor, num_k_heads, n_seqs);
+
+        // Fix dimension order: last two should be [tokens, batches]
+        q_conv = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_tokens, n_seqs);
+        k_conv = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_tokens, n_seqs);
+    }
 
     // Call the new ggml_delta_net function with the corrected flow
-    ggml_tensor * output = ggml_delta_net(q_conv, k_conv, v_conv, gate, beta, state_broadcast, true, 1.0f, il);
-    cb(q_conv, "delta_output", il);
-
-    // Extract the output part
-    ggml_tensor * attn_out =
-        ggml_view_4d(ctx0, output, head_dim, n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1], output->nb[2], 0);
-    cb(output, "attn_out", il);
-
-    // Extract the new state
-    ggml_tensor * new_state =
-        ggml_view_4d(ctx0, output, head_dim, head_dim * n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1],
-                     output->nb[2], n_tokens * n_seqs * head_dim * n_heads * ggml_element_size(output));
-    cb(output, "new_state", il);
-
-    // Only return the last recurrent state
-    struct ggml_tensor * state_reshaped = ggml_cont_4d(ctx0, new_state, head_dim, head_dim, n_heads, n_tokens * n_seqs);
-    struct ggml_tensor * state_last =
-        ggml_view_4d(ctx0, state_reshaped, head_dim, head_dim, n_heads, 1, state_reshaped->nb[1], state_reshaped->nb[2],
-                     state_reshaped->nb[3], head_dim * head_dim * n_heads * ((n_seqs * n_tokens) - 1));
-    cb(output, "new_state_last", il);
-
-    // Update the recurrent states
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_last, ssm_states_all));
-
-    // Reshape both attn_out and z to 2D tensors for normalization
-    // attn_out: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim, n_heads * n_tokens * n_seqs);
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(num_k_heads)) : hparams.f_attention_scale;
+    ggml_tensor * attn_out = ggml_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, kq_scale, hparams.f_norm_rms_eps);
+    cb(attn_out, "attn_out", il);
+
+    // The output and the new state were concatenated into a single 1D tensor, so extract them as 1D views as well
+    const int64_t output_flat_size = head_dim * n_heads * n_tokens * n_seqs;
+    ggml_tensor * attn_out_1d =
+        ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
+    cb(attn_out_1d, "attn_out_1d", il);
+
+    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_heads, n_tokens, n_seqs);
+    cb(attn_out_final, "attn_out_final", il);
+
+    // Extract the state part (the second half of the concatenated 1D tensor);
+    // it starts right after the output_flat_size elements of the attention output
+    const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
+
+    ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
+    cb(state_1d, "state_1d", il);
+
+    ggml_tensor * new_state = ggml_reshape_4d(ctx0, state_1d, head_dim, head_dim, n_heads, n_seqs);
+    cb(new_state, "new_state", il);
+
+    // Update the recurrent states - we use the new_state directly since it's already the last state
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, ssm_states_all));
+
+    // Reshape both attn_out_final and z to 2D tensors for normalization
+    // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
     ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
-    ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+    ggml_tensor * attn_out_norm = build_norm(attn_out_2d_final, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
     cb(attn_out_norm, "attn_out_norm", il);
 
     // Apply silu gate: attn_out_norm * silu(z_2d)
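Since GGML_OP_DELTA_NET packs the attention output and the updated recurrent state into one flat F32 tensor (output elements first, then the state, matching the ne[] computed in ggml_delta_net), the layer recovers the two parts with ggml_view_1d plus a byte offset. A small sketch of that offset arithmetic with hypothetical dimensions:

// Offset arithmetic for unpacking the flat GGML_OP_DELTA_NET result
// (hypothetical sizes; mirrors the ggml_view_1d calls in the layer above).
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t head_dim = 128;   // S_v
    const int64_t n_heads  = 16;    // H_v
    const int64_t n_tokens = 7;     // original (unpadded) token count
    const int64_t n_seqs   = 1;

    const int64_t output_flat_size = head_dim * n_heads * n_tokens * n_seqs;   // attention output elements
    const int64_t state_flat_size  = head_dim * head_dim * n_heads * n_seqs;   // recurrent state elements

    const size_t elt = sizeof(float);                                          // the result tensor is F32
    printf("total elements = %lld\n", (long long) (output_flat_size + state_flat_size));
    printf("output view: %lld elements at byte offset 0\n", (long long) output_flat_size);
    printf("state  view: %lld elements at byte offset %zu\n",
           (long long) state_flat_size, (size_t) output_flat_size * elt);
    return 0;
}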
@@ -562,7 +424,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
-    cb(output, "final_output", il);
+    cb(final_output, "final_output", il);
 
     // Output projection
     cur = build_lora_mm(model.layers[il].ssm_out, final_output);

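The gated normalization used in the linear-attention layer above (Qwen3NextRMSNormGated, i.e. RMSNorm(x) scaled by silu(z)) comes down to simple per-channel arithmetic. The sketch below uses made-up values and shows the learned ssm_norm weight as an explicit factor w; it illustrates the formula only, not the build_norm implementation.

// Per-channel sketch of the gated RMSNorm applied above: y = RMSNorm(x) * w * silu(z).
// Hypothetical vector length and values; w stands for the ssm_norm weight.
#include <math.h>
#include <stdio.h>

int main(void) {
    const int   n    = 4;
    const float eps  = 1e-6f;
    const float x[4] = {0.5f, -1.0f, 2.0f, 0.25f};  // attn_out for one (token, head)
    const float z[4] = {0.1f, 0.3f, -0.2f, 1.5f};   // gate branch for the same slot
    const float w[4] = {1.0f, 1.0f, 1.0f, 1.0f};    // ssm_norm weight (identity here)

    float ss = 0.0f;
    for (int i = 0; i < n; ++i) {
        ss += x[i] * x[i];
    }
    const float inv_rms = 1.0f / sqrtf(ss / n + eps);

    for (int i = 0; i < n; ++i) {
        const float silu_z = z[i] / (1.0f + expf(-z[i]));     // silu(z) = z * sigmoid(z)
        const float y      = x[i] * inv_rms * w[i] * silu_z;
        printf("y[%d] = %f\n", i, y);
    }
    return 0;
}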
+ 1 - 10
src/models/llm_build_qwen3next.h

@@ -11,16 +11,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
 private:
     // ggml_delta_net
-    struct ggml_tensor * ggml_delta_net(struct ggml_tensor * k,
-                                        struct ggml_tensor * v,
-                                        struct ggml_tensor * q,
-                                        struct ggml_tensor * g,
-                                        struct ggml_tensor * beta,
-                                        struct ggml_tensor * state,
-                                        bool                 use_qk_l2norm,
-                                        float                scale,
-                                        int                  il);
-
     ggml_tensor * ggml_delta_net_op(struct ggml_tensor * q,
                                    struct ggml_tensor * k,
                                    struct ggml_tensor * v,
@@ -29,6 +19,7 @@ private:
                                    struct ggml_tensor * state,
                                    bool                 use_qk_l2norm,
                                    float                scale,
+                                   float                eps_norm,
                                    int                  il);
 
     ggml_tensor * build_qwen3next_attention_layer(ggml_tensor *             cur,