
Parity on delta!

Piotr Wilkin, 3 months ago
parent
commit
666fc0583d
2 changed files with 201 additions and 527 deletions
  1. ggml/src/ggml-cpu/ops.cpp (+194 −522)
  2. src/models/llm_build_qwen3next.cpp (+7 −5)
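As a reading aid (editor's note, not part of the commit page): the kernel rewritten below follows the chunk-wise gated delta rule described by the Python reference quoted in the code comments. In the notation assumed here (Q_i, K_i, V_i are the per-chunk query/key/value blocks, D_i the decay mask, M the causal mask, g the per-token gate, K_i^cumdecay the k_cumdecay term, and S the recurrent state), one chunk roughly computes:

$$
\begin{aligned}
A_i &= \big(Q_i K_i^{\top} \odot D_i\big) \odot M, \\
V_i^{\text{new}} &= V_i - K_i^{\text{cumdecay}} S_{i-1}, \\
O_i &= \big(Q_i \odot e^{g_i}\big) S_{i-1} + A_i V_i^{\text{new}}, \\
S_i &= e^{g_{i,\text{last}}}\, S_{i-1} + \big(K_i \odot e^{g_{i,\text{last}} - g_i}\big)^{\top} V_i^{\text{new}}.
\end{aligned}
$$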

ggml/src/ggml-cpu/ops.cpp  +194 −522

@@ -10458,7 +10458,7 @@ void ggml_compute_forward_gla(
 }
 
 // Helper function to compute cumulative sum
-static void ggml_cumsum_f32(const float * x, float * dst, const int64_t n) {
+static void delta_cumsum_f32(const float * x, float * dst, const int64_t n) {
     float cumsum = 0.0f;
     for (int64_t i = 0; i < n; i++) {
         cumsum += x[i];
@@ -10467,7 +10467,7 @@ static void ggml_cumsum_f32(const float * x, float * dst, const int64_t n) {
 }
 
 // Helper function for matrix multiplication
-static void ggml_matmul_f32(const float * a, const float * b, float * dst,
+static void delta_matmul_f32(const float * a, const float * b, float * dst,
                               const int64_t m, const int64_t n, const int64_t k) {
     for (int64_t i = 0; i < m; i++) {
         for (int64_t j = 0; j < n; j++) {
@@ -10481,7 +10481,7 @@ static void ggml_matmul_f32(const float * a, const float * b, float * dst,
 }
 
 // Helper function to create upper triangular mask
-static void ggml_create_upper_triangular_mask(bool * mask, const int64_t size) {
+static void delta_create_upper_triangular_mask(bool * mask, const int64_t size) {
     for (int64_t i = 0; i < size; i++) {
         for (int64_t j = 0; j < size; j++) {
             mask[i * size + j] = (j >= i); // upper triangular with diagonal
@@ -10505,7 +10505,7 @@ static void ggml_compute_chunk_decay_mask_f32(const float * g_cumsum, float * de
 }
 
 // Helper function to compute k_beta @ key.T
-static void ggml_compute_k_beta_key_t_f32(const float * k_beta, const float * key,
+static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * key,
                                              float * k_beta_key_t,
                                              const int64_t chunk_size, const int64_t k_head_dim) {
     for (int64_t i = 0; i < chunk_size; i++) {
@@ -10522,7 +10522,7 @@ static void ggml_compute_k_beta_key_t_f32(const float * k_beta, const float * ke
 }
 
 // Helper function to apply triangular updates
-static void ggml_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
+static void delta_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
     for (int64_t i = 1; i < chunk_size; i++) {
         for (int64_t j = 0; j < i; j++) {
             float sum = 0.0f;
@@ -10535,14 +10535,14 @@ static void ggml_apply_triangular_updates_f32(float * attn, const int64_t chunk_
 }
 
 // Helper function to add identity matrix
-static void ggml_add_identity_matrix_f32(float * matrix, const int64_t size) {
+static void delta_add_identity_matrix_f32(float * matrix, const int64_t size) {
     for (int64_t i = 0; i < size; i++) {
         matrix[i * size + i] += 1.0f;
     }
 }
 
 // Helper function to compute value = attn @ v_beta
-static void ggml_compute_value_f32(const float * attn, const float * v_beta,
+static void delta_compute_value_f32(const float * attn, const float * v_beta,
                                       float * value,
                                       const int64_t chunk_size, const int64_t v_head_dim) {
     for (int64_t i = 0; i < chunk_size; i++) {
@@ -10558,21 +10558,19 @@ static void ggml_compute_value_f32(const float * attn, const float * v_beta,
 }
 
 // Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-static void ggml_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const double * g_exp,
-                                           float * k_cumdecay,
-                                           const int64_t chunk_size, const int64_t k_head_dim) {
+static void delta_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const float * g,
+                                           float * k_cumdecay, const int64_t chunk_size, const int64_t k_head_dim) {
     for (int64_t i = 0; i < chunk_size; i++) {
         for (int64_t d = 0; d < k_head_dim; d++) {
             float sum = 0.0f;
             for (int64_t j = 0; j < chunk_size; j++) {
                 int64_t k_beta_idx = j * k_head_dim + d;
-                sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * g_exp[j];
+                sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * expf(g[j]);
             }
             k_cumdecay[i * k_head_dim + d] = sum;
         }
     }
 }
-// Helper functions for delta net computation
 
 // Matrix multiplication helper for delta net
 static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, const int64_t cols_a, const int64_t cols_b,
@@ -10590,47 +10588,76 @@ static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, con
     }
 }
 
-// Helper function to apply triangular updates to attention matrix
-static void delta_net_compute_diagonal_updates(
-    float * attn,
-    const int64_t chunk_size,
-    const int64_t n_heads,
-    const int64_t n_seqs) {
-    
-    // Apply triangular updates like in the Python reference:
-    // for i in range(1, chunk_size):
-    //     row = attn[..., i, :i].clone()
-    //     sub = attn[..., :i, :i].clone()
-    //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
-    
-    for (int64_t head = 0; head < n_heads; head++) {
-        for (int64_t seq = 0; seq < n_seqs; seq++) {
-            // Get pointer to this head's attention matrix
-            float * attn_head = attn + (head * chunk_size * chunk_size * n_seqs) + (seq * chunk_size * chunk_size);
-            
-            // Create temporary storage for the original values to avoid in-place modification
-            float * attn_copy = (float *) malloc(chunk_size * chunk_size * sizeof(float));
-            memcpy(attn_copy, attn_head, chunk_size * chunk_size * sizeof(float));
-            
-            // Apply triangular updates using the original values
-            for (int64_t i = 1; i < chunk_size; i++) {
-                for (int64_t j = 0; j < i; j++) {
-                    float sum = 0.0f;
-                    // Compute (row.unsqueeze(-1) * sub).sum(-2) using original values
-                    for (int64_t k = 0; k < i; k++) {
-                        sum += attn_copy[i * chunk_size + k] * attn_copy[k * chunk_size + j];
-                    }
-                    attn_head[i * chunk_size + j] = attn_copy[i * chunk_size + j] + sum;
-                }
+// Helper function to compute q_i @ k_i.transpose(-1, -2) * decay_mask and apply mask
+static void delta_compute_q_k_attn_f32(const float * q, const float * k, const float * decay_mask,
+                                       float * attn, const bool * mask,
+                                       const int64_t chunk_size, const int64_t head_dim) {
+    // Compute q @ k.transpose(-1, -2)
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t j = 0; j < chunk_size; j++) {
+            float sum = 0.0f;
+            for (int64_t d = 0; d < head_dim; d++) {
+                int64_t q_idx = i * head_dim + d;
+                int64_t k_idx = j * head_dim + d;
+                sum += q[q_idx] * k[k_idx];
+            }
+            // Apply decay mask and causal mask
+            int64_t attn_idx = i * chunk_size + j;
+            attn[attn_idx] = (mask[attn_idx] ? 0.0f : sum * decay_mask[attn_idx]);
+        }
+    }
+}
+
+// Helper function for matrix multiplication with state tensors
+static void delta_matmul_state_f32(const float * a, const float * state, float * dst,
+                                   const int64_t rows_a, const int64_t cols_a, const int64_t cols_state) {
+    for (int64_t i = 0; i < rows_a; i++) {
+        for (int64_t j = 0; j < cols_state; j++) {
+            float sum = 0.0f;
+            for (int64_t k = 0; k < cols_a; k++) {
+                int64_t a_idx = i * cols_a + k;
+                int64_t state_idx = k * cols_state + j;
+                sum += a[a_idx] * state[state_idx];
             }
+            dst[i * cols_state + j] = sum;
+        }
+    }
+}
+
+// Helper function for element-wise tensor subtraction
+static void delta_tensor_subtract_f32(const float * a, const float * b, float * dst, const int64_t size) {
+    for (int64_t i = 0; i < size; i++) {
+        dst[i] = a[i] - b[i];
+    }
+}
+
+// Helper function for element-wise tensor addition
+static void delta_tensor_add_f32(const float * a, const float * b, float * dst, const int64_t size) {
+    for (int64_t i = 0; i < size; i++) {
+        dst[i] = a[i] + b[i];
+    }
+}
+
+// Helper function to update recurrent state
+static void delta_update_recurrent_state_f32(const float * last_state, const float * g_last,
+                                             const float * k_i, const float * g_diff_exp, const float * v_new, float * new_state,
+                                             const int64_t chunk_size, const int64_t k_head_dim, const int64_t v_head_dim) {
+    for (int64_t i = 0; i < k_head_dim; i++) {
+        for (int64_t j = 0; j < v_head_dim; j++) {
+            int64_t state_idx = i * v_head_dim + j;
             
-            // Add identity matrix
-            for (int64_t i = 0; i < chunk_size; i++) {
-                attn_head[i * chunk_size + i] += 1.0f;
+            // last_recurrent_state * g_last
+            float term1 = last_state[state_idx] * (*g_last);
+            
+            // (k_i * g_diff_exp).transpose(-1, -2) @ v_new
+            float term2 = 0.0f;
+            for (int64_t k = 0; k < chunk_size; k++) {
+                int64_t k_idx = k * k_head_dim + i;
+                int64_t v_idx = k * v_head_dim + j;
+                term2 += k_i[k_idx] * g_diff_exp[k] * v_new[v_idx];
             }
             
-            free(attn_copy);
+            new_state[state_idx] = term1 + term2;
         }
     }
 }
@@ -10650,10 +10677,10 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const int64_t S_k               = (int64_t) dst->op_params[1];
     const int64_t S_v               = (int64_t) dst->op_params[2];
     const int64_t original_n_tokens = (int64_t) dst->op_params[3];  // Get original sequence length
-    const int64_t H_k               = H_v;
     const int64_t n_tokens          = original_n_tokens;            // Use the original sequence length
     const int64_t n_seqs            = src0->ne[3];                  // q tensor has n_seqs in dim 3
 
+
     // Add assertions to verify tensor dimensions
     GGML_ASSERT(src0->ne[3] == n_seqs);  // q tensor
     GGML_ASSERT(src1->ne[3] == n_seqs);  // k tensor
@@ -10669,14 +10696,13 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const int ith = params->ith;
     // const int nth = params->nth;  // nth is unused
 
-    // For chunked implementation, we process all sequences in thread 0 for simplicity
-    // This can be optimized later to parallelize across sequences
+    // TODO: parallelize across heads and sequences
     if (ith != 0) {
         return;
     }
 
     // Clear output and new state section
-    memset(output, 0, ((S_v * H_v * n_tokens) + (S_v * H_v * S_v * n_seqs)) * sizeof(float));
+    memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
 
     // Calculate chunk size
     const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
@@ -10684,518 +10710,164 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
 
     // Apply triangular updates to the precomputed attention matrix
-    // This is the missing piece that was causing the attention matrix to have all zeros
     float * attn_data = (float *) src8->data;
-    delta_net_compute_diagonal_updates(attn_data, chunk_size, H_v, n_seqs);
-    
-    // Debug: Check attention matrix after triangular updates
-    float attn_after_updates_sum = 0.0f;
-    float attn_after_updates_max = 0.0f;
-    for (int64_t i = 0; i < chunk_size * chunk_size * H_v * n_seqs; i++) {
-        attn_after_updates_sum += attn_data[i];
-        attn_after_updates_max = fmaxf(attn_after_updates_max, fabsf(attn_data[i]));
-    }
-    
-    for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
-        for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {            
-            const int64_t v_beta_offset = (head_idx * src6->nb[2] + seq_idx * src6->nb[3]) / sizeof(float);
-            const int64_t k_beta_offset = (head_idx * src7->nb[2] + seq_idx * src7->nb[3]) / sizeof(float);
-            const int64_t attn_offset = (head_idx * src8->nb[2] + seq_idx * src8->nb[3]) / sizeof(float);
-            const int64_t g_offset = (head_idx * src3->nb[2] + seq_idx * src3->nb[3]) / sizeof(float);
-            
-            // Fixed memory access patterns with bounds checking
-            float * attn_precomputed = (float *) src8->data + attn_offset;
-            float * v_beta_ptr = (float *) src6->data + v_beta_offset;
-            float * k_beta_ptr = (float *) src7->data + k_beta_offset;
-            float * g_vals = (float *) src3->data + g_offset;
-            
-            // Add bounds checking to prevent out-of-bounds access
-            const int64_t attn_total_elements = src8->ne[0] * src8->ne[1] * src8->ne[2] * src8->ne[3];
-            const int64_t v_beta_total_elements = src6->ne[0] * src6->ne[1] * src6->ne[2] * src6->ne[3];
-            const int64_t k_beta_total_elements = src7->ne[0] * src7->ne[1] * src7->ne[2] * src7->ne[3];
-            const int64_t g_total_elements = src3->ne[0] * src3->ne[1] * src3->ne[2] * src3->ne[3];
-                        
-            // Compute value = attn @ v_beta with deterministic tensor access
-            // printf("C++ DEBUG: Computing value = attn @ v_beta with deterministic tensor access\n");
-            float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
-            
-            // Calculate tensor strides for deterministic access
-            const int64_t attn_stride0 = src8->nb[0] / sizeof(float);  // chunk_size dimension
-            const int64_t v_beta_stride0 = src6->nb[0] / sizeof(float);  // S_v dimension
-            const int64_t v_beta_stride1 = src6->nb[1] / sizeof(float);  // chunk_size dimension
-            
-            // printf("C++ DEBUG: Tensor strides for deterministic access:\n");
-            // printf("  attn_stride0=%ld, v_beta_stride0=%ld, v_beta_stride1=%ld\n",
-            //        attn_stride0, v_beta_stride0, v_beta_stride1);
-            
-            for (int64_t i = 0; i < chunk_size; i++) {
-                for (int64_t d = 0; d < S_v; d++) {
-                    float sum = 0.0f;
-                    for (int64_t j = 0; j < chunk_size; j++) {
-                        // Deterministic tensor access using stride calculations
-                        // attn[i][j] access: i * attn_stride0 + j
-                        int64_t attn_idx = i * attn_stride0 + j;
-                        if (attn_idx >= chunk_size * chunk_size) {
-                            // printf("ERROR: attn access out of bounds: attn_idx=%ld, max=%ld\n",
-                            //        attn_idx, chunk_size * chunk_size);
-                            continue;
-                        }
-                        
-                        // v_beta[j][d] access: j * v_beta_stride1 + d
-                        int64_t v_beta_idx = j * v_beta_stride1 + d;
-                        if (v_beta_idx >= chunk_size * S_v) {
-                            // printf("ERROR: v_beta access out of bounds: v_beta_idx=%ld, max=%ld\n",
-                            //        v_beta_idx, chunk_size * S_v);
-                            continue;
-                        }
-                        
-                        float attn_val = attn_precomputed[attn_idx];
-                        float v_beta_val = v_beta_ptr[v_beta_idx];
-                        
-                        if (isnan(attn_val) || isnan(v_beta_val)) {
-                            // printf("ERROR: NaN detected in matrix multiplication: attn=%f, v_beta=%f\n", attn_val, v_beta_val);
-                            continue;
-                        }
-                        
-                        // Debug: Print first few multiplications for validation
-                        if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2 && j < 2) {
-                            // printf("C++ DEBUG value[%ld][%ld]: attn[%ld][%ld]=%f * v_beta[%ld][%ld]=%f = %f\n",
-                            //        i, d, i, j, attn_val, j, d, v_beta_val, attn_val * v_beta_val);
-                        }
-                        sum += attn_val * v_beta_val;
-                    }
-                    value[i * S_v + d] = sum;
-                    // Debug: Print first few results for validation
-                    if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2) {
-                        // printf("C++ DEBUG value[%ld][%ld] = sum = %f\n", i, d, value[i * S_v + d]);
-                    }
-                }
-            }
-            
-            float value_sum = 0.0f;
-            for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                value_sum += value[i];
-            }
-            // printf("C++ PRE-CHUNK value_sum = %f (head %ld, seq %ld)\n", value_sum, head_idx, seq_idx);
-            
-            // Compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) with deterministic tensor access
-            // printf("C++ DEBUG: Computing k_cumdecay = attn @ (k_beta * g.exp()) with deterministic tensor access\n");
-            float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
-            
-            // Calculate tensor strides for deterministic access
-            const int64_t k_beta_stride0 = src7->nb[0] / sizeof(float);  // S_k dimension
-            const int64_t k_beta_stride1 = src7->nb[1] / sizeof(float);  // chunk_size dimension
-            const int64_t g_stride0 = src3->nb[0] / sizeof(float);  // chunk_size dimension
-            
-            // printf("C++ DEBUG: k_cumdecay tensor strides: k_beta_stride0=%ld, k_beta_stride1=%ld, g_stride0=%ld\n",
-            //        k_beta_stride0, k_beta_stride1, g_stride0);
-            
-            for (int64_t i = 0; i < chunk_size; i++) {
-                for (int64_t d = 0; d < S_k; d++) {
-                    float sum = 0.0f;
-                    for (int64_t j = 0; j < chunk_size; j++) {
-                        // Deterministic tensor access using stride calculations
-                        // attn[i][j] access: i * attn_stride0 + j
-                        int64_t attn_idx = i * attn_stride0 + j;
-                        if (attn_idx >= chunk_size * chunk_size) {
-                            // printf("ERROR: attn access out of bounds in k_cumdecay: attn_idx=%ld, max=%ld\n",
-                            //        attn_idx, chunk_size * chunk_size);
-                            continue;
-                        }
-                        
-                        // k_beta[j][d] access: j * k_beta_stride1 + d
-                        int64_t k_beta_idx = j * k_beta_stride1 + d;
-                        if (k_beta_idx >= chunk_size * S_k) {
-                            // printf("ERROR: k_beta access out of bounds: k_beta_idx=%ld, max=%ld\n",
-                            //        k_beta_idx, chunk_size * S_k);
-                            continue;
-                        }
-                        
-                        // g tensor layout: [chunk_size, n_heads, n_seqs, 1]
-                        // Deterministic access: g[j + head_idx * chunk_size + seq_idx * chunk_size * n_heads]
-                        int64_t g_idx = j + head_idx * chunk_size + seq_idx * chunk_size * H_v;
-                        if (g_idx >= chunk_size * H_v * n_seqs) {
-                            // printf("ERROR: g tensor out of bounds: g_idx=%ld, max=%ld\n",
-                            //        g_idx, chunk_size * H_v * n_seqs);
-                            continue;
-                        }
-                        
-                        float attn_val = attn_precomputed[attn_idx];
-                        float k_beta_val = k_beta_ptr[k_beta_idx];
-                        float g_val = g_vals[g_idx];
-                        float g_exp = expf(g_val);
-                        
-                        if (isnan(attn_val) || isnan(k_beta_val) || isnan(g_val) || isnan(g_exp)) {
-                            // printf("ERROR: NaN detected in k_cumdecay multiplication: attn=%f, k_beta=%f, g_val=%f, g_exp=%f\n",
-                            //        attn_val, k_beta_val, g_val, g_exp);
-                            continue;
-                        }
-                        
-                        // Debug: Print first few multiplications for validation
-                        if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2 && j < 2) {
-                            // printf("C++ DEBUG k_cumdecay[%ld][%ld]: attn[%ld][%ld]=%f * k_beta[%ld][%ld]=%f * g_exp[%ld]=%f = %f\n",
-                            //        i, d, i, j, attn_val, j, d, k_beta_val, j, g_exp,
-                            //        attn_val * k_beta_val * g_exp);
-                        }
-                        sum += attn_val * k_beta_val * g_exp;
-                    }
-                    k_cumdecay[i * S_k + d] = sum;
-                    // Debug: Print first few results for validation
-                    if (seq_idx == 0 && head_idx == 0 && i < 2 && d < 2) {
-                        // printf("C++ DEBUG k_cumdecay[%ld][%ld] = sum = %f\n", i, d, k_cumdecay[i * S_k + d]);
-                    }
-                }
-            }
-            
-            float k_cumdecay_sum = 0.0f;
-            for (int64_t i = 0; i < chunk_size * S_k; i++) {
-                k_cumdecay_sum += k_cumdecay[i];
-            }
-            
-            free(value);
-            free(k_cumdecay);
-        }
-    }
-
-    // Initialize last_recurrent_state
-    // last_recurrent_state = torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
-    // if initial_state is None else initial_state.to(value)
-    float * initial_state_ptr = (float *) src4->data;
-    
-    // If initial_state is provided, copy it to new_state, otherwise initialize new_state to zeros
-    // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
-    // This means: [n_heads * v_head_dim, v_head_dim * n_seqs, 1, 1]
-    // So total size is: S_v * H_v * S_v * n_seqs
-    if (initial_state_ptr != NULL) {
-        // Copy initial state to new state
-        memcpy(new_state, initial_state_ptr, S_v * H_v * S_v * n_seqs * sizeof(float));
-    } else {
-        // Initialize new state to zeros
-        memset(new_state, 0, S_v * H_v * S_v * n_seqs * sizeof(float));
-    }
-    
-    // Process each chunk for the main computation
-    // Following the Python reference implementation exactly
-    for (int64_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
-        for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
-            // Process each head in this chunk
-            for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
-                // Get the recurrent state for this sequence and head
-                // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
-                // For each head and sequence, we need to find the correct state slice
-                // The state is organized as: [head_idx * S_v * S_v * n_seqs + seq_idx * S_v * S_v]
-                float * last_recurrent_state = new_state + (head_idx * S_v * S_v * n_seqs) + (seq_idx * S_v * S_v);
-                // printf("\n=== C++ Processing chunk %ld, seq %ld, head %ld ===\n", chunk_idx, seq_idx, head_idx);
-                // Get pointers to current chunk data for this head
-                // GGML tensor layout: [S_k/S_v, chunk_size, H_v, n_seqs]
-                // Python layout: [batch_size, sequence_length, num_heads, k_head_dim]
-                // After transpose: [batch_size, num_heads, sequence_length, k_head_dim]
-                
-                // For GGML: ne[0]=S_k/S_v, ne[1]=chunk_size, ne[2]=H_v, ne[3]=n_seqs
-                // nb[0]=sizeof(float)*S_k/S_v, nb[1]=sizeof(float)*S_k/S_v*chunk_size, etc.
-                
-                const int64_t q_offset = (seq_idx * src0->nb[3] + head_idx * src0->nb[2]) / sizeof(float);
-                const int64_t k_offset = (seq_idx * src1->nb[3] + head_idx * src1->nb[2]) / sizeof(float);
-                const int64_t v_offset = (seq_idx * src2->nb[3] + head_idx * src2->nb[2]) / sizeof(float);
-                const int64_t g_offset = (seq_idx * src3->nb[3] + head_idx * src3->nb[2]) / sizeof(float);
-                const int64_t v_beta_offset = (seq_idx * src6->nb[3] + head_idx * src6->nb[2]) / sizeof(float);
-                const int64_t k_beta_offset = (seq_idx * src7->nb[3] + head_idx * src7->nb[2]) / sizeof(float);
-                const int64_t attn_offset = (seq_idx * src8->nb[3] + head_idx * src8->nb[2]) / sizeof(float);
-                const int64_t decay_mask_offset = (seq_idx * src5->nb[3] + head_idx * src5->nb[2]) / sizeof(float);
-                
-                // Calculate strides for each tensor
-                const int64_t q_stride0 = src0->nb[0] / sizeof(float);  // S_k
-                const int64_t q_stride1 = src0->nb[1] / sizeof(float);  // chunk_size
-                const int64_t k_stride0 = src1->nb[0] / sizeof(float);  // S_k
-                const int64_t k_stride1 = src1->nb[1] / sizeof(float);  // chunk_size
-                const int64_t v_stride0 = src2->nb[0] / sizeof(float);  // S_v
-                const int64_t v_stride1 = src2->nb[1] / sizeof(float);  // chunk_size
-                const int64_t g_stride0 = src3->nb[0] / sizeof(float);  // chunk_size
-                const int64_t v_beta_stride0 = src6->nb[0] / sizeof(float);  // S_v
-                const int64_t v_beta_stride1 = src6->nb[1] / sizeof(float);  // chunk_size
-                const int64_t k_beta_stride0 = src7->nb[0] / sizeof(float);  // S_k
-                const int64_t k_beta_stride1 = src7->nb[1] / sizeof(float);  // chunk_size
-                const int64_t attn_stride0 = src8->nb[0] / sizeof(float);  // chunk_size
-                const int64_t attn_stride1 = src8->nb[1] / sizeof(float);  // chunk_size
-                const int64_t decay_mask_stride0 = src5->nb[0] / sizeof(float);  // chunk_size
-                const int64_t decay_mask_stride1 = src5->nb[1] / sizeof(float);  // chunk_size
-                
-                // Get decay mask for this chunk and head
-                float * decay_mask = (float *) src5->data + decay_mask_offset;
-                
-                // Use pre-computed attention matrix from src8 (after triangular updates)
-                // The Python reference computes triangular updates before the chunk loop
-                float * attn_precomputed = (float *) src8->data + attn_offset;
-                
-                // Debug: print precomputed attention matrix values
-                float attn_precomputed_sum = 0.0f;
-                float attn_precomputed_max = 0.0f;
-                for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
-                    attn_precomputed_sum += attn_precomputed[i];
-                    attn_precomputed_max = fmaxf(attn_precomputed_max, fabsf(attn_precomputed[i]));
-                }
-                // Get g values for this chunk and head
-                float * g_vals = (float *) src3->data + g_offset;
-                
-                // Get v_beta and k_beta for this chunk and head
-                float * v_beta_ptr = (float *) src6->data + v_beta_offset;
-                float * k_beta_ptr = (float *) src7->data + k_beta_offset;
-                
-                // Debug: print v_beta and k_beta values
-                float v_beta_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                    v_beta_sum += v_beta_ptr[i];
-                }
-                // printf("C++ v_beta_sum = %f\n", v_beta_sum);
-                
-                float k_beta_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_k; i++) {
-                    k_beta_sum += k_beta_ptr[i];
-                }
-                // printf("C++ k_beta_sum = %f\n", k_beta_sum);
-                
-                // Compute value = attn_precomputed @ v_beta with bounds checking
-                float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
-                for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
-                        float sum = 0.0f;
-                        for (int64_t j = 0; j < chunk_size; j++) {
-                            // Bounds checking for matrix multiplication
-                            if (i * chunk_size + j >= chunk_size * chunk_size ||
-                                j * v_beta_stride0 + d * v_beta_stride1 >= chunk_size * S_v) {
-                                // printf("ERROR: Chunk value matrix multiplication out of bounds: i=%ld, j=%ld, d=%ld\n", i, j, d);
-                                continue;
-                            }
-                            
-                            float attn_val = attn_precomputed[i * chunk_size + j];
-                            float v_beta_val = v_beta_ptr[j * v_beta_stride0 + d * v_beta_stride1];
-                            
-                            // Check for NaN values to prevent propagation
-                            if (isnan(attn_val) || isnan(v_beta_val)) {
-                                // printf("ERROR: NaN detected in chunk value multiplication: attn=%f, v_beta=%f\n", attn_val, v_beta_val);
-                                continue;
-                            }
-                            
-                            sum += attn_val * v_beta_val;
-                        }
-                        value[i * S_v + d] = sum;
-                    }
-                }
-                
-                float value_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                    value_sum += value[i];
-                }
-                
-                // Compute k_cumdecay = attn_precomputed @ (k_beta * g.exp().unsqueeze(-1)) with bounds checking
-                float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
-                for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_k; d++) {
-                        float sum = 0.0f;
-                        for (int64_t j = 0; j < chunk_size; j++) {
-                            // Bounds checking for matrix multiplication
-                            if (i * chunk_size + j >= chunk_size * chunk_size ||
-                                j * k_beta_stride0 + d * k_beta_stride1 >= chunk_size * S_k ||
-                                j * g_stride0 >= chunk_size) {
-                                // printf("ERROR: Chunk k_cumdecay matrix multiplication out of bounds: i=%ld, j=%ld, d=%ld\n", i, j, d);
-                                continue;
-                            }
-                            
-                            float attn_val = attn_precomputed[i * chunk_size + j];
-                            float k_beta_val = k_beta_ptr[j * k_beta_stride0 + d * k_beta_stride1];
-                            float g_val = g_vals[j * g_stride0];
-                            float g_exp = expf(g_val);
-                            
-                            // Check for NaN values to prevent propagation
-                            if (isnan(attn_val) || isnan(k_beta_val) || isnan(g_val) || isnan(g_exp)) {
-                                // printf("ERROR: NaN detected in chunk k_cumdecay multiplication: attn=%f, k_beta=%f, g_val=%f, g_exp=%f\n",
-                                //        attn_val, k_beta_val, g_val, g_exp);
-                                continue;
-                            }
-                            
-                            sum += attn_val * k_beta_val * g_exp;
-                        }
-                        k_cumdecay[i * S_k + d] = sum;
-                    }
-                }
-                
-                float k_cumdecay_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_k; i++) {
-                    k_cumdecay_sum += k_cumdecay[i];
-                }
+    float * v_beta_data = (float *) src6->data;
+    float * k_beta_data = (float *) src7->data;
+    float * g_data = (float *) src3->data;
+    float * q_data = (float *) src0->data;
+    float * k_data = (float *) src1->data;
+    //float * v_data = (float *) src2->data;
+    float * state_data = (float *) src4->data;
+    float * decay_mask_data = (float *) src5->data;
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(src2));
+    GGML_ASSERT(ggml_is_contiguous(src3));
+    GGML_ASSERT(ggml_is_contiguous(src4));
+    GGML_ASSERT(ggml_is_contiguous(src5));
+    GGML_ASSERT(ggml_is_contiguous(src6));
+    GGML_ASSERT(ggml_is_contiguous(src7));
+    GGML_ASSERT(ggml_is_contiguous(src8));
+
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
+                float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
+                float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
+                float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
+                delta_apply_triangular_updates_f32(attn_data_for_chs, chunk_size);
+                delta_add_identity_matrix_f32(attn_data_for_chs, chunk_size);
+                // Calculate the correct v_beta and k_beta pointers for this head and sequence
+                float * v_beta_chunk = v_beta_data + (src6->nb[3] / sizeof(float)) * seq + (src6->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
+                float * k_beta_chunk = k_beta_data + (src7->nb[3] / sizeof(float)) * seq + (src7->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
+                // Select the slice of gate values g for the tokens of this chunk,
+                // head and sequence; it is used by delta_compute_k_cumdecay_f32
+                // below and by the gating terms (g_last, g_diff_exp) computed
+                // further down in the chunk loop
+                float * g_chunk = g_data + (src3->nb[3] / sizeof(float)) * seq + (src3->nb[2] / sizeof(float)) * head + (src3->nb[1] / sizeof(float)) * (chunk * chunk_size);
+                delta_compute_value_f32(attn_data_for_chs, v_beta_chunk, value_chunk, chunk_size, S_v);
+                delta_compute_k_cumdecay_f32(attn_data_for_chs, k_beta_chunk, g_chunk, k_cumdecay, chunk_size, S_k);
+                // Now compute the per-chunk-specific part (corresponding to the inner loop in Python)
+                float * q_chunk = q_data + (src0->nb[3] / sizeof(float)) * seq + (src0->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
+                float * k_chunk = k_data + (src1->nb[3] / sizeof(float)) * seq + (src1->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
+                float * decay_mask_chunk = decay_mask_data + (src5->nb[3] / sizeof(float)) * seq + (src5->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
+                float * k_cumdecay_chunk = k_cumdecay + (S_v * chunk_size * H_v) * seq + (S_v * chunk_size) * head;
                 
-                // Compute fresh attention matrix for this chunk, just like Python reference line 118
-                // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+                // Allocate temporary variables for the loop
                 float * attn = (float *) malloc(chunk_size * chunk_size * sizeof(float));
+                float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
+                float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
+                float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
+                float * core_attn_out_chunk = (float *) malloc(chunk_size * S_v * sizeof(float));
+                float * g_last = (float *) malloc(sizeof(float));
+                float * g_diff_exp = (float *) malloc(chunk_size * sizeof(float));
+                bool * mask = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
                 
-                // First compute q_i @ k_i.transpose(-1, -2) with bounds checking
+                // Create upper triangular mask for causal attention (exclude diagonal)
                 for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < chunk_size; j++) {
-                        float sum = 0.0f;
-                        for (int64_t d = 0; d < S_k; d++) {
-                            // Bounds checking for q and k tensor access
-                            int64_t q_idx = q_offset + d * q_stride0 + i * q_stride1;
-                            int64_t k_idx = k_offset + d * k_stride0 + j * k_stride1;
-                            
-                            if (q_idx >= src0->ne[0] * src0->ne[1] * src0->ne[2] * src0->ne[3] ||
-                                k_idx >= src1->ne[0] * src1->ne[1] * src1->ne[2] * src1->ne[3]) {
-                                // printf("ERROR: q/k tensor access out of bounds: q_idx=%ld, k_idx=%ld\n", q_idx, k_idx);
-                                continue;
-                            }
-                            
-                            float q_val = ((float *)src0->data)[q_idx];
-                            float k_val = ((float *)src1->data)[k_idx];
-                            
-                            // Check for NaN values to prevent propagation
-                            if (isnan(q_val) || isnan(k_val)) {
-                                // printf("ERROR: NaN detected in q@k multiplication: q_val=%f, k_val=%f\n", q_val, k_val);
-                                continue;
-                            }
-                            
-                            sum += q_val * k_val;
-                        }
-                        
-                        // Bounds checking for decay mask access
-                        if (i * chunk_size + j >= chunk_size * chunk_size) {
-                            // printf("ERROR: decay mask access out of bounds: i=%ld, j=%ld\n", i, j);
-                            attn[i * chunk_size + j] = 0.0f;
-                            continue;
-                        }
-                        
-                        float decay_val = decay_mask[i * chunk_size + j];
-                        if (isnan(decay_val)) {
-                            // printf("ERROR: NaN detected in decay mask: decay_val=%f\n", decay_val);
-                            attn[i * chunk_size + j] = 0.0f;
-                            continue;
-                        }
-                        
-                        attn[i * chunk_size + j] = sum * decay_val;
-                    }
-                }
-                
-                // Apply upper triangular mask (masked_fill_(mask, 0))
-                for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t j = i + 1; j < chunk_size; j++) {
-                        attn[i * chunk_size + j] = 0.0f;
+                        mask[i * chunk_size + j] = (j > i); // True for upper triangular (excluding diagonal)
                     }
                 }
+                                                
+                // Python loop implementation:
+                // q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+                // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+                delta_compute_q_k_attn_f32(q_chunk, k_chunk, decay_mask_chunk, attn, mask, chunk_size, S_k);
                 
-                // Compute v_prime = k_cumdecay @ last_recurrent_state
-                float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
-                for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
-                        float sum = 0.0f;
-                        for (int64_t k = 0; k < S_k; k++) {
-                            sum += k_cumdecay[i * S_k + k] * last_recurrent_state[k * S_v + d];
-                        }
-                        v_prime[i * S_v + d] = sum;
-                    }
-                }
+                // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+                // Calculate the correct state pointer for this head and sequence
+                float * head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
                 
-                // Debug prints for key intermediate values
-                float attn_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
-                    attn_sum += attn[i];
-                }
                 
+                delta_matmul_state_f32(k_cumdecay_chunk, head_state_data, v_prime, chunk_size, S_k, S_v);
                 
-                float v_prime_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                    v_prime_sum += v_prime[i];
-                }
+                // v_new = v_i - v_prime
+                delta_tensor_subtract_f32(value_chunk, v_prime, v_new, chunk_size * S_v);
                 
-                // Compute v_new = v_i - v_prime
-                float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
+                // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+                float * q_g_exp = (float *) malloc(chunk_size * S_k * sizeof(float));
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
-                        v_new[i * S_v + d] = ((float *)src2->data)[v_offset + d * v_stride0 + i * v_stride1] - v_prime[i * S_v + d];
+                    for (int64_t d = 0; d < S_k; d++) {
+                        int64_t q_idx = i * S_k + d;
+                        q_g_exp[q_idx] = q_chunk[q_idx] * expf(g_chunk[i]);
                     }
                 }
+                delta_matmul_state_f32(q_g_exp, head_state_data, attn_inter, chunk_size, S_k, S_v);
                 
-                float v_new_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                    v_new_sum += v_new[i];
-                }
+                // core_attn_out[:, :, i] = attn_inter + attn @ v_new
+                float * attn_v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
+                delta_matmul_state_f32(attn, v_new, attn_v_new, chunk_size, chunk_size, S_v);
+                delta_tensor_add_f32(attn_inter, attn_v_new, core_attn_out_chunk, chunk_size * S_v);
                 
-                // Compute attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-                float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
+                // Store the result in the output tensor
                 for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t d = 0; d < S_v; d++) {
-                        float sum = 0.0f;
-                        float g_exp = expf(g_vals[i * g_stride0]);
-                        for (int64_t k = 0; k < S_k; k++) {
-                            sum += ((float *)src0->data)[q_offset + k * q_stride0 + i * q_stride1] * g_exp * last_recurrent_state[k * S_v + d];
-                        }
-                        attn_inter[i * S_v + d] = sum;
+                        if ((chunk * chunk_size + i) >= n_tokens) continue;
+                        int64_t output_idx = seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d; 
+                        output[output_idx] = core_attn_out_chunk[i * S_v + d];
                     }
                 }
                 
-                float attn_inter_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                    attn_inter_sum += attn_inter[i];
-                }
-                
-                // Compute core_attn_out = attn_inter + attn @ v_new
-                // Output tensor layout: [S_v * H_v, n_tokens, 1, 1]
-                const int64_t out_offset = head_idx * (S_v * n_tokens) + chunk_idx * (S_v * chunk_size);
-                float * core_attn_out = output + out_offset;
+                // g_last = g[:, :, i, -1, None, None].exp()
+                *g_last = expf(g_chunk[chunk_size - 1]);
                 
+                // Prepare g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
+                float g_last_val = g_chunk[chunk_size - 1];
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
-                        float sum = 0.0f;
-                        for (int64_t j = 0; j < chunk_size; j++) {
-                            sum += attn[i * chunk_size + j] * v_new[j * S_v + d];
-                        }
-                        core_attn_out[i * S_v + d] = attn_inter[i * S_v + d] + sum;
-                    }
+                    g_diff_exp[i] = expf(g_last_val - g_chunk[i]);
                 }
                 
-                float core_attn_out_sum = 0.0f;
-                for (int64_t i = 0; i < chunk_size * S_v; i++) {
-                    core_attn_out_sum += core_attn_out[i];
-                }                
+                // last_recurrent_state = (
+                //     last_recurrent_state * g_last
+                //     + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
+                // )
+                float * new_recurrent_state = (float *) malloc(S_v * S_v * sizeof(float));
                 
-                float g_last = g_vals[chunk_size - 1];
-                float g_last_exp = expf(g_last);
                 
-                // First part: last_recurrent_state * g_last_exp
-                for (int64_t k = 0; k < S_k; k++) {
-                    for (int64_t v = 0; v < S_v; v++) {
-                        last_recurrent_state[k * S_v + v] *= g_last_exp;
-                    }
-                }
+                delta_update_recurrent_state_f32(head_state_data, g_last, k_chunk, g_diff_exp, v_new,
+                                                 new_recurrent_state, chunk_size, S_v, S_v);
                 
-                // Second part: (k_i * (g_last - g).exp()).transpose(-1, -2) @ v_new
-                float * k_gated = (float *) malloc(chunk_size * S_k * sizeof(float));
-                for (int64_t i = 0; i < chunk_size; i++) {
-                    float g_diff_exp = expf(g_last - g_vals[i]);
-                    for (int64_t d = 0; d < S_k; d++) {
-                        k_gated[i * S_k + d] = ((float *)src1->data)[k_offset + d * k_stride0 + i * k_stride1] * g_diff_exp;
+                
+                // Store the new state
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
+                        new_state[state_idx] = new_recurrent_state[i * S_v + j];
                     }
                 }
                 
-                // Compute k_gated.T @ v_new
-                for (int64_t k = 0; k < S_k; k++) {
-                    for (int64_t v = 0; v < S_v; v++) {
-                        float sum = 0.0f;
-                        for (int64_t i = 0; i < chunk_size; i++) {
-                            sum += k_gated[i * S_k + k] * v_new[i * S_v + v];
-                        }
-                        last_recurrent_state[k * S_v + v] += sum;
+                // Update the original state tensor with the new state for the next chunk
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = i * S_v + j;
+                        head_state_data[state_idx] = new_recurrent_state[state_idx];
                     }
                 }
                 
+                // Recalculate head_state_data to point to the updated state for the next iteration
+                head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
+                
                 // Free temporary memory
                 free(attn);
-                free(value);
-                free(k_cumdecay);
                 free(v_prime);
                 free(v_new);
                 free(attn_inter);
-                free(k_gated);
+                free(core_attn_out_chunk);
+                free(g_last);
+                free(g_diff_exp);
+                free(mask);
+                free(q_g_exp);
+                free(attn_v_new);
+                free(new_recurrent_state);
+                
+                // Free the value and k_cumdecay allocated at the beginning of the loop
+                free(value_chunk);
+                free(k_cumdecay);
             }
         }
-    }
+    }    
 }
 
 // ggml_compute_forward_rwkv_wkv7
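Editor's note (not from the commit): the per-chunk state update above, carried out by delta_update_recurrent_state_f32 together with the g_last/g_diff_exp preparation in the chunk loop, can be illustrated with a small standalone sketch on plain row-major float buffers. The function name, dimensions and values below are invented for illustration, and unlike the kernel (which passes g_last already exponentiated) this sketch exponentiates inside the loop.

// Sketch: S_new = S_old * exp(g_last) + (K ⊙ exp(g_last - g))^T @ V_new
#include <cmath>
#include <cstdio>
#include <vector>

static void update_state(const float * S_old, float g_last,
                         const float * K, const float * g, const float * V_new,
                         float * S_new, int chunk, int dk, int dv) {
    for (int i = 0; i < dk; i++) {
        for (int j = 0; j < dv; j++) {
            float acc = S_old[i * dv + j] * std::exp(g_last);  // decay the previous state
            for (int t = 0; t < chunk; t++) {                  // rank-`chunk` delta update
                acc += K[t * dk + i] * std::exp(g_last - g[t]) * V_new[t * dv + j];
            }
            S_new[i * dv + j] = acc;
        }
    }
}

int main() {
    const int chunk = 2, dk = 2, dv = 2;
    std::vector<float> S_old(dk * dv, 0.0f);                 // previous recurrent state
    std::vector<float> K     = {1.0f, 0.0f, 0.0f, 1.0f};     // [chunk, dk]
    std::vector<float> V_new = {0.5f, 0.0f, 0.0f, 0.5f};     // [chunk, dv]
    std::vector<float> g     = {-0.1f, -0.2f};               // per-token gate values
    std::vector<float> S_new(dk * dv);
    update_state(S_old.data(), g.back(), K.data(), g.data(), V_new.data(),
                 S_new.data(), chunk, dk, dv);
    printf("S_new = [%.4f %.4f; %.4f %.4f]\n", S_new[0], S_new[1], S_new[2], S_new[3]);
    return 0;
}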

src/models/llm_build_qwen3next.cpp  +7 −5

@@ -229,8 +229,8 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
     cb(beta, "beta_reshape", il);
 
-    g = ggml_cont(ctx, ggml_permute(ctx, g, 1, 0, 3, 2));
-    cb(g, "g_reshape", il);
+    g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
+    cb(g, "g_permute", il);
 
     // Then, pad the second dimension (n_tokens) to chunk_size
     q = ggml_pad(ctx, q, 0, pad_size, 0, 0); 
@@ -250,7 +250,7 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     GGML_ASSERT(k->ne[1] % GGML_DELTA_NET_CHUNK == 0 && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
     GGML_ASSERT(v->ne[1] % GGML_DELTA_NET_CHUNK == 0 && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
     GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == H_k && beta->ne[2] == 1 && beta->ne[3] == n_seqs);
-    GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[1] == H_k && g->ne[2] == 1 && g->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[2] == H_k && g->ne[1] == 1 && g->ne[3] == n_seqs);
 
     ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
     ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
@@ -265,6 +265,8 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     cb(k_beta, "k_beta", il);
     k = ggml_reshape_4d(ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
     cb(k_beta, "k_reshape", il);
+    g = ggml_reshape_4d(ctx, g, GGML_DELTA_NET_CHUNK, 1, H_k * num_chunks, n_seqs);
+    cb(g, "g_reshape", il);
     struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
     cb(g_cumsum, "g_cumsum", il);
         
@@ -298,7 +300,7 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     cb(attn, "attn_in", il);
 
     // We'll be returning the result as a 1D tensor due to the dimensions mismatch of the state and output tensors
-    const int64_t ne[1] = { (S_v * H_v * n_tokens) + (S_v * S_v * H_v * n_seqs) };
+    const int64_t ne[1] = { (S_v * H_v * n_tokens * n_seqs ) + (S_v * S_v * H_v * n_seqs) };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 1, ne);
 
     ggml_set_op_params_i32(result, 0, H_v);
@@ -548,7 +550,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
     
-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_heads, n_tokens, n_seqs);
+    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs);
     cb(attn_out_final, "attn_out_final", il);
    
     // Extract the state part (second part of the concatenated tensor)
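Editor's note (not from the commit): the enlarged flattened result size above implies that the 1D F32 buffer now holds the attention output for all sequences first, followed by the per-head recurrent states. The sketch below illustrates that layout, assuming the same index ordering as the output_idx/state_idx expressions in ggml_compute_forward_delta_net_f32 and that the state block starts immediately after the output block (as the memset over both parts suggests); the struct name and the dimensions used in main() are invented for illustration.

#include <cstdint>
#include <cstdio>

struct delta_net_layout {
    int64_t S_v, H_v, n_tokens, n_seqs;

    int64_t output_size() const { return S_v * H_v * n_tokens * n_seqs; }
    int64_t state_size()  const { return S_v * S_v * H_v * n_seqs; }
    int64_t total_size()  const { return output_size() + state_size(); }

    // Flat index of output[seq][head][token][d]
    int64_t output_index(int64_t seq, int64_t head, int64_t token, int64_t d) const {
        return seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + token * S_v + d;
    }
    // Flat index of state[seq][head][i][j], offset past the output block
    int64_t state_index(int64_t seq, int64_t head, int64_t i, int64_t j) const {
        return output_size() + seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
    }
};

int main() {
    delta_net_layout l = {128, 16, 64, 2};  // hypothetical S_v, H_v, n_tokens, n_seqs
    printf("total floats: %lld (output %lld + state %lld)\n",
           (long long) l.total_size(), (long long) l.output_size(), (long long) l.state_size());
    printf("state(seq=1, head=3, 0, 0) starts at %lld\n", (long long) l.state_index(1, 3, 0, 0));
    return 0;
}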