Delta.net chunked reimplemented

Piotr Wilkin, 3 months ago
Commit 75586ea36e
3 changed files with 521 additions and 173 deletions
  1. ggml/src/ggml-cpu/ggml-cpu.c (+1, -1)
  2. ggml/src/ggml-cpu/ops.cpp (+516, -171)
  3. src/models/llm_build_qwen3next.cpp (+4, -1)

+ 1 - 1
ggml/src/ggml-cpu/ggml-cpu.c

@@ -2295,6 +2295,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_POOL_2D:
         case GGML_OP_POOL_2D_BACK:
         case GGML_OP_DELTA_NET_RECURRENT:
+        case GGML_OP_DELTA_NET:
             {
                 n_tasks = 1;
             } break;
@@ -2312,7 +2313,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;
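Note on this hunk: GGML_OP_DELTA_NET moves from the n_tasks = n_threads group into the single-task group, next to GGML_OP_DELTA_NET_RECURRENT. This matches the reworked kernel in ops.cpp below, which now loops over all sequences, heads and chunks itself (the old per-thread partitioning is commented out) and does nothing on threads other than thread 0. A minimal sketch of the guard the single-task registration pairs with; in the actual kernel it is folded into the memset branch:

    // Only thread 0 does any work for the chunked DeltaNet op; the other
    // threads return immediately, which is why n_tasks is set to 1 above.
    if (params->ith != 0) {
        return;
    }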

+ 516 - 171
ggml/src/ggml-cpu/ops.cpp

@@ -10529,7 +10529,69 @@ static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * k
     }
 }
 
-// Helper function to apply triangular updates
+// Helper function to apply triangular updates to entire chunk (all sequences and heads)
+static void delta_apply_triangular_updates_chunk_f32(float * attn, const int64_t chunk_size,
+                                                    const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            
+            // Apply triangular updates following the Python reference exactly:
+            // for i in range(1, chunk_size):
+            //     row = attn[..., i, :i].clone()
+            //     sub = attn[..., :i, :i].clone()
+            //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+            for (int64_t i = 1; i < chunk_size; i++) {
+                // Create temporary storage for row and sub to avoid modifying during computation
+                float * row = (float *) malloc(i * sizeof(float));
+                float * sub = (float *) malloc(i * i * sizeof(float));
+                
+                // Copy row = attn[..., i, :i]
+                for (int64_t j = 0; j < i; j++) {
+                    row[j] = attn_ptr[i * chunk_size + j];
+                }
+                
+                // Copy sub = attn[..., :i, :i]
+                for (int64_t k = 0; k < i; k++) {
+                    for (int64_t j = 0; j < i; j++) {
+                        sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                    }
+                }
+                
+                // Compute updates for each j in :i
+                for (int64_t j = 0; j < i; j++) {
+                    // Compute (row.unsqueeze(-1) * sub).sum(-2)
+                    float sum_val = 0.0f;
+                    for (int64_t k = 0; k < i; k++) {
+                        sum_val += row[k] * sub[k * i + j];
+                    }
+                                       
+                    // Update: attn[..., i, j] = row[j] + sum_val
+                    attn_ptr[i * chunk_size + j] = row[j] + sum_val;
+                }
+                
+                free(row);
+                free(sub);
+            }
+        }
+    }
+}
+
+// Helper function to add identity matrix to entire chunk (all sequences and heads)
+static void delta_add_identity_matrix_chunk_f32(float * matrix, const int64_t chunk_size,
+                                                const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            // Add identity matrix directly
+            for (int64_t i = 0; i < chunk_size; i++) {
+                matrix_ptr[i * chunk_size + i] += 1.0f;
+            }
+        }
+    }
+}
+
+// Helper function to apply triangular updates (original version for individual matrices)
 static void delta_apply_triangular_updates_f32(float * attn, const int64_t chunk_size) {
     for (int64_t i = 1; i < chunk_size; i++) {
         for (int64_t j = 0; j < i; j++) {
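The row update in delta_apply_triangular_updates_chunk_f32 is a forward substitution: row i is replaced by row + row @ attn[:i, :i], so once the identity is added each chunk_size x chunk_size block holds, in effect, (I - A)^{-1} of its original strictly lower-triangular contents A. A minimal standalone sketch (hypothetical 3x3 values, not taken from the kernel) that reproduces the recurrence and checks the inverse property:

    #include <cstdio>
    #include <cstring>

    int main() {
        const int n = 3;
        // Hypothetical strictly lower-triangular block A.
        float A[9] = { 0.0f,   0.0f,  0.0f,
                       0.5f,   0.0f,  0.0f,
                       0.25f, -0.75f, 0.0f };
        float attn[9];
        memcpy(attn, A, sizeof(A));

        // Same recurrence as delta_apply_triangular_updates_chunk_f32:
        // attn[i, :i] = row + row @ attn[:i, :i]
        for (int i = 1; i < n; i++) {
            float row[3], sub[9];
            for (int j = 0; j < i; j++) { row[j] = attn[i * n + j]; }
            for (int k = 0; k < i; k++) {
                for (int j = 0; j < i; j++) { sub[k * i + j] = attn[k * n + j]; }
            }
            for (int j = 0; j < i; j++) {
                float s = 0.0f;
                for (int k = 0; k < i; k++) { s += row[k] * sub[k * i + j]; }
                attn[i * n + j] = row[j] + s;
            }
        }
        for (int i = 0; i < n; i++) { attn[i * n + i] += 1.0f; }  // + identity

        // (I - A) @ attn should print the identity matrix.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                float s = 0.0f;
                for (int k = 0; k < n; k++) {
                    s += ((i == k ? 1.0f : 0.0f) - A[i * n + k]) * attn[k * n + j];
                }
                printf("%6.3f ", s);
            }
            printf("\n");
        }
        return 0;
    }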
@@ -10542,40 +10604,41 @@ static void delta_apply_triangular_updates_f32(float * attn, const int64_t chunk
     }
 }
 
-// Helper function to add identity matrix
+// Helper function to add identity matrix (original version for individual matrices)
 static void delta_add_identity_matrix_f32(float * matrix, const int64_t size) {
     for (int64_t i = 0; i < size; i++) {
         matrix[i * size + i] += 1.0f;
     }
 }
 
-// Helper function to compute value = attn @ v_beta
-static void delta_compute_value_f32(const float * attn, const float * v_beta,
-                                      float * value,
-                                      const int64_t chunk_size, const int64_t v_head_dim) {
-    for (int64_t i = 0; i < chunk_size; i++) {
-        for (int64_t d = 0; d < v_head_dim; d++) {
-            float sum = 0.0f;
-            for (int64_t j = 0; j < chunk_size; j++) {
-                int64_t v_beta_idx = j * v_head_dim + d;
-                sum += attn[i * chunk_size + j] * v_beta[v_beta_idx];
-            }
-            value[i * v_head_dim + d] = sum;
+static void delta_compute_value_f32(const float * attn,
+                                    const float * v_beta,
+                                    float *       value,
+                                    const int64_t chunk_size,
+                                    const int64_t v_head_dim,
+                                    const int64_t n_heads,
+                                    const int64_t n_seqs) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < n_heads; head++) {
+            delta_matmul_f32(
+                attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * head, 
+                v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head, 
+                value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head, 
+                chunk_size, v_head_dim, chunk_size);
         }
     }
 }
 
-// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+// Helper function to compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) for single head/sequence
 static void delta_compute_k_cumdecay_f32(const float * attn, const float * k_beta, const float * g,
-                                           float * k_cumdecay, const int64_t chunk_size, const int64_t k_head_dim) {
+                                        float * k_cumdecay, const int64_t chunk_size, const int64_t k_head_dim) {
     for (int64_t i = 0; i < chunk_size; i++) {
-        for (int64_t d = 0; d < k_head_dim; d++) {
+        for (int64_t j = 0; j < k_head_dim; j++) {
             float sum = 0.0f;
-            for (int64_t j = 0; j < chunk_size; j++) {
-                int64_t k_beta_idx = j * k_head_dim + d;
-                sum += attn[i * chunk_size + j] * k_beta[k_beta_idx] * expf(g[j]);
+            for (int64_t k = 0; k < chunk_size; k++) {
+                sum += attn[i * chunk_size + k] * k_beta[k * k_head_dim + j] * expf(g[k]);
             }
-            k_cumdecay[i * k_head_dim + d] = sum;
+            k_cumdecay[i * k_head_dim + j] = sum;
         }
     }
 }
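The chunk-level helpers in this file all address their flat buffers the same way: data is laid out as [n_seqs][H_v][chunk_size][dim] in row-major order, so a (seq, head) block starts at seq * (chunk_size * dim * H_v) + head * (chunk_size * dim). A hypothetical helper (not part of the patch) that captures this indexing:

    #include <cstdint>

    // Illustration only: start of the (seq, head) block inside a flat buffer
    // laid out as [n_seqs][H_v][chunk_size][dim], row-major.
    static inline float * chunk_block(float * base, int64_t seq, int64_t head,
                                      int64_t chunk_size, int64_t dim, int64_t H_v) {
        return base + seq * (chunk_size * dim * H_v) + head * (chunk_size * dim);
    }

In this notation, the new delta_compute_value_f32 computes value = attn @ v_beta once per (seq, head) block via delta_matmul_f32, while delta_compute_k_cumdecay_f32 still works on a single block and computes k_cumdecay[i, d] = sum_k attn[i, k] * k_beta[k, d] * exp(g[k]).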
@@ -10625,7 +10688,9 @@ static void delta_matmul_state_f32(const float * a, const float * state, float *
             for (int64_t k = 0; k < cols_a; k++) {
                 int64_t a_idx = i * cols_a + k;
                 int64_t state_idx = k * cols_state + j;
-                sum += a[a_idx] * state[state_idx];
+                float a_val = a[a_idx];
+                float state_val = state[state_idx];
+                sum += a_val * state_val;
             }
             dst[i * cols_state + j] = sum;
         }
@@ -10670,6 +10735,108 @@ static void delta_update_recurrent_state_f32(const float * last_state, const flo
     }
 }
 
+// Helper function to compute q_i @ k_i.transpose(-1, -2) * decay_mask and apply mask for entire chunk
+static void delta_compute_q_k_attn_chunk_f32(const float * q, const float * k, const float * decay_mask,
+                                             float * attn, const bool * mask,
+                                             const int64_t chunk_size, const int64_t head_dim,
+                                             const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * q_ptr = q + seq * (chunk_size * head_dim * H_v) + head * (chunk_size * head_dim);
+            const float * k_ptr = k + seq * (chunk_size * head_dim * H_v) + head * (chunk_size * head_dim);
+            const float * decay_mask_ptr = decay_mask + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+            delta_compute_q_k_attn_f32(q_ptr, k_ptr, decay_mask_ptr, attn_ptr, mask, chunk_size, head_dim);
+        }
+    }
+}
+
+// Helper function for matrix multiplication with state tensors for entire chunk
+static void delta_matmul_state_chunk_f32(const float * a, const float * state, float * dst,
+                                        const int64_t rows_a, const int64_t cols_a, const int64_t cols_state,
+                                        const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * a_ptr = a + seq * (rows_a * cols_a * H_v) + head * (rows_a * cols_a);
+            const float * state_ptr = state + seq * (cols_a * cols_state * H_v) + head * (cols_a * cols_state);
+            float * dst_ptr = dst + seq * (rows_a * cols_state * H_v) + head * (rows_a * cols_state);
+            delta_matmul_state_f32(a_ptr, state_ptr, dst_ptr, rows_a, cols_a, cols_state);
+        }
+    }
+}
+
+// Helper function to update recurrent state for entire chunk
+static void delta_update_recurrent_state_chunk_f32(const float * state, const float * g_last,
+                                                  const float * k, const float * g_diff_exp, const float * v_new, float * new_state,
+                                                  const int64_t chunk_size, const int64_t k_head_dim, const int64_t v_head_dim,
+                                                  const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * state_ptr = state + seq * (k_head_dim * v_head_dim * H_v) + head * (k_head_dim * v_head_dim);
+            const float * k_ptr = k + seq * (chunk_size * k_head_dim * H_v) + head * (chunk_size * k_head_dim);
+            const float * g_diff_exp_ptr = g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
+            const float * v_new_ptr = v_new + seq * (chunk_size * v_head_dim * H_v) + head * (chunk_size * v_head_dim);
+            float * new_state_ptr = new_state + seq * (k_head_dim * v_head_dim * H_v) + head * (k_head_dim * v_head_dim);
+                        
+            for (int64_t i = 0; i < k_head_dim; i++) {
+                for (int64_t j = 0; j < v_head_dim; j++) {
+                    int64_t state_idx = i * v_head_dim + j;
+                    
+                    // last_recurrent_state * g_last
+                    float term1 = state_ptr[state_idx] * g_last[seq * H_v + head];
+                    
+                    // (k_i * g_diff_exp).transpose(-1, -2) @ v_new
+                    float term2 = 0.0f;
+                    for (int64_t k = 0; k < chunk_size; k++) {
+                        int64_t k_idx = k * k_head_dim + i;
+                        int64_t v_idx = k * v_head_dim + j;
+                        term2 += k_ptr[k_idx] * g_diff_exp_ptr[k] * v_new_ptr[v_idx];
+                    }
+                    
+                    new_state_ptr[state_idx] = term1 + term2;
+                }
+            }
+        }
+    }
+}
+
+// Helper function for element-wise tensor subtraction for entire chunk
+static void delta_tensor_subtract_chunk_f32(const float * a, const float * b, float * dst, const int64_t size,
+                                           const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * a_ptr = a + seq * (size * H_v) + head * size;
+            const float * b_ptr = b + seq * (size * H_v) + head * size;
+            float * dst_ptr = dst + seq * (size * H_v) + head * size;
+            delta_tensor_subtract_f32(a_ptr, b_ptr, dst_ptr, size);
+        }
+    }
+}
+
+// Helper function for element-wise tensor addition for entire chunk
+static void delta_tensor_add_chunk_f32(const float * a, const float * b, float * dst, const int64_t size,
+                                       const int64_t n_seqs, const int64_t H_v) {
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            const float * a_ptr = a + seq * (size * H_v) + head * size;
+            const float * b_ptr = b + seq * (size * H_v) + head * size;
+            float * dst_ptr = dst + seq * (size * H_v) + head * size;
+            delta_tensor_add_f32(a_ptr, b_ptr, dst_ptr, size);
+        }
+    }
+}
+
+
+static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
+    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n", 
+        name, token, data[0], data[1], data[2], data[3], data[4]);
+    double sum = 0.0;
+    for (unsigned int i = 0; i < size; i++) {
+        sum += data[i];
+    }
+    GGML_LOG_INFO("total elements: %ld, sum = %.10f\n", size, sum);
+}
+
 void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // q (already normalized and scaled)
     const struct ggml_tensor * src1 = dst->src[1];  // k (already normalized)
@@ -10682,7 +10849,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const struct ggml_tensor * src8 = dst->src[8];  // attn
 
     const int64_t H_v               = (int64_t) dst->op_params[0];
-    const int64_t S_k               = (int64_t) dst->op_params[1];
     const int64_t S_v               = (int64_t) dst->op_params[2];
     const int64_t original_n_tokens = (int64_t) dst->op_params[3];  // Get original sequence length
     const int64_t n_tokens          = original_n_tokens;            // Use the original sequence length
@@ -10698,15 +10864,17 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
     float * dst_data  = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
-    float * output    = dst_data; // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
-    float * new_state = dst_data + (S_v * H_v * n_tokens);  // [S_v * H_v, S_v * n_seqs, 1, 1]
+    float * output    = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
+    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs);  // [S_v * H_v, S_v * n_seqs, 1, 1]
 
     const int ith = params->ith;
-    const int nth = params->nth;  // nth is unused
+    // const int nth = params->nth;  // nth is unused
 
     // Clear output and new state section
     if (ith == 0) {
         memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    } else {
+        return;
     }
 
     // Calculate chunk size
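As the pointer setup in the hunk above shows, the destination tensor is packed like the GLA ops: the per-token outputs come first (S_v * H_v * n_tokens * n_seqs floats) and the updated recurrent state follows immediately (S_v * S_v * H_v * n_seqs floats); the GGML_ASSERT at the end of the kernel checks exactly this. A small worked example with hypothetical sizes, only to illustrate the packed layout:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical dimensions; not taken from any particular model.
        const int64_t S_v = 128, H_v = 16, n_tokens = 100, n_seqs = 2;
        const int64_t out_elems   = S_v * H_v * n_tokens * n_seqs;  // output block
        const int64_t state_elems = S_v * S_v * H_v * n_seqs;       // state block
        printf("output: %lld floats, new_state starts at offset %lld, total %lld\n",
               (long long) out_elems, (long long) out_elems,
               (long long) (out_elems + state_elems));
        return 0;
    }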
@@ -10714,16 +10882,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
     const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
 
-    // Apply triangular updates to the precomputed attention matrix
-    float * attn_data = (float *) src8->data;
-    float * v_beta_data = (float *) src6->data;
-    float * k_beta_data = (float *) src7->data;
-    float * g_data = (float *) src3->data;
-    float * q_data = (float *) src0->data;
-    float * k_data = (float *) src1->data;
-    //float * v_data = (float *) src2->data;
     float * state_data = (float *) src4->data;
-    float * decay_mask_data = (float *) src5->data;
 
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
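Worked example for the padding arithmetic above (the chunk_size assignment itself is elided between the hunks; assuming a value of 64 purely for illustration): with n_tokens = 100, pad_size = (64 - 100 % 64) % 64 = 28 and num_chunks = (100 + 28) / 64 = 2, so the last chunk carries 28 padded positions.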
@@ -10735,161 +10894,347 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     GGML_ASSERT(ggml_is_contiguous(src7));
     GGML_ASSERT(ggml_is_contiguous(src8));
 
-    int64_t total_params = n_seqs * H_v * num_chunks;
-    int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
+    // int64_t total_params = n_seqs * H_v * num_chunks;
+    // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
+
+    // Create helper lambda for state tensor access
+    const auto state_ptr = [state_data, src4] (int64_t seq, int64_t head, int64_t i, int64_t j) {
+        return state_data + (j * src4->nb[0] / sizeof(float)) + (i * src4->nb[1] / sizeof(float)) +
+            (head * src4->nb[2] / sizeof(float)) + (seq * src4->nb[3] / sizeof(float));
+    };
+    
+    float * attn          = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
+    float * value         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    float * k_cumdecay    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    bool *  mask          = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
+    float * g =             (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+
+    // Create upper triangular mask for causal attention (exclude diagonal)
+    for (int64_t i = 0; i < chunk_size; i++) {
+        for (int64_t j = 0; j < chunk_size; j++) {
+            mask[i * chunk_size + j] = (j > i);  // True for upper triangular (excluding diagonal)
+        }
+    }
+
+    // Make a copy of the attention tensor and the gate cumsum tensor
+    memcpy(attn, src8->data, ggml_nbytes(src8));
+    memcpy(g, src3->data, ggml_nbytes(src3));
+
+    // Prepare the initial attention matrix with triangular updates and identity (for entire chunks)
+    // This corresponds to the reference implementation:
+    // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size)
+    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v);
+    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v);
 
+    // Compute value = attn @ v_beta
+    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs);
     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int64_t head = 0; head < H_v; head++) {
-            for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
-                int64_t tidx = seq * (H_v * num_chunks) + head * num_chunks + chunk;
-                if (tidx < ith * per_thread || tidx >= (ith + 1) * per_thread) {
-                    continue; // not our thread;
-                }
-                float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
-                float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
-                delta_apply_triangular_updates_f32(attn_data_for_chs, chunk_size);
-                delta_add_identity_matrix_f32(attn_data_for_chs, chunk_size);
-                // Calculate the correct v_beta and k_beta pointers for this head and sequence
-                float * v_beta_chunk = v_beta_data + (src6->nb[3] / sizeof(float)) * seq + (src6->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * k_beta_chunk = k_beta_data + (src7->nb[3] / sizeof(float)) * seq + (src7->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                // The g tensor has dimensions [8, 64, 2, 1] = [features, tokens, heads, sequences]
-                // We need to access the correct head data
-                // For each head, we need to access the correct feature for all tokens in the chunk
-                // Let's try accessing feature index chunk (since we have 8 features and chunk=0)
-                float * g_chunk = g_data + (src3->nb[3] / sizeof(float)) * seq + (src3->nb[2] / sizeof(float)) * head + (src3->nb[1] / sizeof(float)) * (chunk * chunk_size);
-                delta_compute_value_f32(attn_data_for_chs, v_beta_chunk, value_chunk, chunk_size, S_v);
-                delta_compute_k_cumdecay_f32(attn_data_for_chs, k_beta_chunk, g_chunk, k_cumdecay, chunk_size, S_k);
-                // Now compute the per-chunk-specific part (corresponding to the inner loop in Python)
-                float * q_chunk = q_data + (src0->nb[3] / sizeof(float)) * seq + (src0->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * k_chunk = k_data + (src1->nb[3] / sizeof(float)) * seq + (src1->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * decay_mask_chunk = decay_mask_data + (src5->nb[3] / sizeof(float)) * seq + (src5->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
-                float * k_cumdecay_chunk = k_cumdecay + (S_v * chunk_size * H_v) * seq + (S_v * chunk_size) * head;
-                
-                // Allocate temporary variables for the loop
-                float * attn = (float *) malloc(chunk_size * chunk_size * sizeof(float));
-                float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * core_attn_out_chunk = (float *) malloc(chunk_size * S_v * sizeof(float));
-                float * g_last = (float *) malloc(sizeof(float));
-                float * g_diff_exp = (float *) malloc(chunk_size * sizeof(float));
-                bool * mask = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
-                
-                // Create upper triangular mask for causal attention (exclude diagonal)
+                delta_compute_k_cumdecay_f32(attn + (chunk_size * chunk_size * H_v) * seq + (chunk_size * chunk_size) * head, 
+                    (float *) src7->data + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                    g + (chunk_size * H_v) * seq + chunk_size * head,
+                    k_cumdecay + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                    chunk_size, S_v);
+        }
+    }
+    print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
+
+    // Process each chunk with all sequences and heads together
+    for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
+        GGML_LOG_INFO("\n=== Processing chunk %ld ===\n", chunk);
+
+        // Create lambdas for tensor access similar to recurrent function
+        const auto q_chunk = [chunk, src0](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+            return ggml_get_f32_nd(src0, i, chunk * chunk_size + token_idx, head, seq);
+        };
+        const auto k_chunk = [chunk, src1](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+            return ggml_get_f32_nd(src1, i, chunk * chunk_size + token_idx, head, seq);
+        };
+        const auto g_chunk = [chunk, src3](int64_t seq, int64_t head, int64_t token_idx) {
+            return ggml_get_f32_nd(src3, chunk * chunk_size + token_idx, 0, head, seq);
+        };
+
+        // Allocate per-chunk arrays containing all sequences and heads
+        float * temp_state    = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
+        float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * attn_inter    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * v_new         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * v_prime       = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * g_diff_exp    = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+        float * g_last        = (float *) malloc(H_v * n_seqs * sizeof(float));
+
+        // Initialize temp_state with zeros for all sequences and heads (state should be empty initially)
+        memset(temp_state, 0, S_v * S_v * H_v * n_seqs * sizeof(float));
+    
+        // Create temporary arrays for entire chunk
+        float * q_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * k_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * q_g_exp         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * attn_v_new      = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+
+        // Fill temporary arrays with data from all sequences and heads
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * g_ptr       = g + seq * (chunk_size * H_v) + head * chunk_size;
+                float * q_g_exp_ptr = q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                // Fill q, k, decay_mask, and g data
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t j = 0; j < chunk_size; j++) {
-                        mask[i * chunk_size + j] = (j > i); // True for upper triangular (excluding diagonal)
+                    for (int64_t d = 0; d < S_v; d++) {
+                        q_ptr[i * S_v + d] = q_chunk(seq, head, i, d);
+                        k_ptr[i * S_v + d] = k_chunk(seq, head, i, d);
                     }
+                    g_ptr[i] = g_chunk(seq, head, i);
                 }
-                                                
-                // Python loop implementation:
-                // q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
-                // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-                delta_compute_q_k_attn_f32(q_chunk, k_chunk, decay_mask_chunk, attn, mask, chunk_size, S_k);
-                
-                // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-                // Calculate the correct state pointer for this head and sequence
-                float * head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
-                
-                
-                delta_matmul_state_f32(k_cumdecay_chunk, head_state_data, v_prime, chunk_size, S_k, S_v);
-                
-                // v_new = v_i - v_prime
-                delta_tensor_subtract_f32(value_chunk, v_prime, v_new, chunk_size * S_v);
-                
-                // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-                float * q_g_exp = (float *) malloc(chunk_size * S_k * sizeof(float));
+
+                // Compute q_g_exp = q * g.exp()
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_k; d++) {
-                        int64_t q_idx = i * S_k + d;
-                        q_g_exp[q_idx] = q_chunk[q_idx] * expf(g_chunk[i]);
+                    for (int64_t d = 0; d < S_v; d++) {
+                        q_g_exp_ptr[i * S_v + d] = q_ptr[i * S_v + d] * expf(g_ptr[i]);
                     }
                 }
-                delta_matmul_state_f32(q_g_exp, head_state_data, attn_inter, chunk_size, S_k, S_v);
-                
-                // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-                float * attn_v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
-                delta_matmul_state_f32(attn, v_new, attn_v_new, chunk_size, chunk_size, S_v);
-                delta_tensor_add_f32(attn_inter, attn_v_new, core_attn_out_chunk, chunk_size * S_v);
-                
-                // Store the result in the output tensor
+            }
+        }
+
+        print_debug_info(q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
+        print_debug_info(k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
+
+        // Step 4: Compute NEW attention matrix for this chunk: attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        // Note: decay_mask[:, :, i] means we need to use the decay_mask for this specific chunk
+        // The mask applied is the simple causal attention mask: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
+        
+        // Now compute attention for all sequences and heads together
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                const float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                const float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                float * k_trans = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int i = 0; i < S_v; i++) {
+                    for (int j = 0; j < chunk_size; j++) {
+                        k_trans[i * chunk_size + j] = k_ptr[j * S_v + i];
+                    }
+                }
+
+                delta_matmul_f32(q_ptr, k_trans, attn_ptr, chunk_size, chunk_size, S_v);
+            }
+        }
+        print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "q_k_trans", chunk);
+
+
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
-                        if ((chunk * chunk_size + i) >= n_tokens) continue;
-                        int64_t output_idx = seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d; 
-                        output[output_idx] = core_attn_out_chunk[i * S_v + d];
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                        const float * decay_mask_ptr = (float *) src5->data + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                        float attn_val = attn_ptr[i * chunk_size + j] * decay_mask_ptr[i * chunk_size + j];
+                        // Apply simple causal attention mask (upper triangular with diagonal=1)
+                        // This corresponds to: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
+                        if (j > i) {
+                            attn_val = 0.0f;
+                        }
+                        attn_ptr[i * chunk_size + j] = attn_val;
                     }
                 }
-                
-                // g_last = g[:, :, i, -1, None, None].exp()
-                *g_last = expf(g_chunk[chunk_size - 1]);
-                
-                // Prepare g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
-                float g_last_val = g_chunk[chunk_size - 1];
+            }
+        }
+        
+        print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "attn_step4_new_chunk", chunk);
+
+        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
+        delta_matmul_state_chunk_f32(k_cumdecay, state_data, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
+        print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
+
+        // v_new = v_i - v_prime
+        delta_tensor_subtract_chunk_f32(value, v_prime, v_new, chunk_size * S_v, n_seqs, H_v);
+        print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
+
+        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        delta_matmul_state_chunk_f32(q_g_exp, state_data, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
+        print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
+
+        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        // Use regular matrix multiplication for attn @ v_new
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                const float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                const float * v_new_ptr = v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * attn_v_new_ptr = attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                                
+                // Compute attn @ v_new: [chunk_size, chunk_size] @ [chunk_size, S_v] -> [chunk_size, S_v]
+                delta_matmul_f32(attn_ptr, v_new_ptr, attn_v_new_ptr, chunk_size, S_v, chunk_size);
+            }
+        }
+        print_debug_info(attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
+        delta_tensor_add_chunk_f32(attn_inter, attn_v_new, core_attn_out, chunk_size * S_v, n_seqs, H_v);
+        print_debug_info(core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);
+
+        // Prepare g_last and g_diff_exp for state update
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * g_ptr = g + seq * (chunk_size * H_v) + head * chunk_size;
+                float g_last_val         = g_ptr[chunk_size - 1];
+                g_last[seq * H_v + head] = expf(g_last_val);
+
+                float * g_diff_exp_ptr = g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    g_diff_exp[i] = expf(g_last_val - g_chunk[i]);
+                    float diff        = g_last_val - g_ptr[i];
+                    g_diff_exp_ptr[i] = expf(diff);
                 }
-                
-                // last_recurrent_state = (
-                //     last_recurrent_state * g_last
-                //     + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
-                // )
-                float * new_recurrent_state = (float *) malloc(S_v * S_v * sizeof(float));
-                
-                
-                delta_update_recurrent_state_f32(head_state_data, g_last, k_chunk, g_diff_exp, v_new,
-                                                 new_recurrent_state, chunk_size, S_v, S_v);
-                
-                
-                // Store the new state
-                for (int64_t i = 0; i < S_v; i++) {
+            }
+        }
+
+        print_debug_info(g_last, H_v * n_seqs, "g_last_chunk", chunk);
+        print_debug_info(g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
+
+        float * k_g_diffexp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < S_v; j++) {
-                        int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
-                        new_state[state_idx] = new_recurrent_state[i * S_v + j];
+                        k_g_diffexp[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + i * S_v + j] = 
+                            k_chunk(seq, head, i, j) * g_diff_exp[seq * (chunk_size * H_v) + head * chunk_size + i];
                     }
                 }
-                
-                // Update the original state tensor with the new state for the next chunk
+            }
+        }
+        print_debug_info(k_g_diffexp, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp", chunk);
+        float * k_g_diffexp_T = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        k_g_diffexp_T[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + i * chunk_size + j] = 
+                            k_g_diffexp[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + j * S_v + i];
+                    }
+                }
+            }
+        }
+
+        // for (int64_t seq = 0; seq < n_seqs; seq++) {
+        //     for (int64_t head = 0; head < H_v; head++) {
+        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
+        //         for (int i = 0; i < chunk_size; i++) {
+        //             GGML_LOG_INFO("[ ");
+        //             for (int j = 0; j < S_v; j++) {
+        //                 GGML_LOG_INFO("%.6f", k_g_diffexp[(chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head + i * S_v + j]);
+        //                 if (j < chunk_size - 1) {
+        //                     GGML_LOG_INFO(", ");
+        //                 }
+        //             }
+        //             GGML_LOG_INFO("], \n");
+        //         }
+        //         GGML_LOG_INFO("]\n");
+        //     }
+        // }
+
+        print_debug_info(k_g_diffexp_T, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp_T", chunk);
+
+        float * kgd_mul_vnew = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
+
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                delta_matmul_f32(k_g_diffexp_T + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head, 
+                    v_new + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                    kgd_mul_vnew + (S_v * S_v * H_v) * seq + (S_v * S_v) * head,
+                    S_v, S_v, chunk_size);
+            }
+        }
+        print_debug_info(kgd_mul_vnew, S_v * S_v * H_v * n_seqs, "kgd_mul_vnew", chunk);
+        
+        // for (int64_t seq = 0; seq < n_seqs; seq++) {
+        //     for (int64_t head = 0; head < H_v; head++) {
+        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
+        //         for (int i = 0; i < S_v; i++) {
+        //             GGML_LOG_INFO("[ ");
+        //             for (int j = 0; j < S_v; j++) {
+        //                 GGML_LOG_INFO("%.6f", kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + i * S_v + j]);
+        //                 if (j < S_v - 1) {
+        //                     GGML_LOG_INFO(", ");
+        //                 }
+        //             }
+        //             GGML_LOG_INFO("], \n");
+        //         }
+        //         GGML_LOG_INFO("]\n");
+        //     }
+        // }
+
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int i = 0; i < S_v; i++) {
+                    for (int j = 0; j < S_v; j++) {
+                        temp_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] = 
+                            state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] + 
+                            kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, S_v * S_v * H_v * n_seqs, "temp_state", chunk);
+
+        // Free temporary memory
+        free(q_chunk_data);
+        free(k_chunk_data);
+        free(q_g_exp);
+        free(attn_v_new);
+        free(kgd_mul_vnew);
+        free(k_g_diffexp_T);
+        free(k_g_diffexp);
+
+        // Store output for this chunk (all sequences and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+
+                // Store output for this chunk
+                for (int64_t i = 0; i < n_tokens; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        int64_t output_idx =
+                            seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
+                        output[output_idx] = core_attn_out_ptr[i * S_v + d];
+                    }
+                }
+            }
+        }
+        print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
+
+        // Update state tensor (all sequences and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * temp_state_ptr = temp_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v);
+
                 for (int64_t i = 0; i < S_v; i++) {
                     for (int64_t j = 0; j < S_v; j++) {
-                        int64_t state_idx = i * S_v + j;
-                        head_state_data[state_idx] = new_recurrent_state[state_idx];
+                        int64_t state_idx             = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
+                        new_state[state_idx]          = temp_state_ptr[i * S_v + j];
+                        *(state_ptr(seq, head, i, j)) = temp_state_ptr[i * S_v + j];
                     }
                 }
-                
-                // Recalculate head_state_data to point to the updated state for the next iteration
-                head_state_data = state_data + (seq * S_v * S_v * H_v) + (head * S_v * S_v);
-                
-                // Free temporary memory
-                free(attn);
-                free(v_prime);
-                free(v_new);
-                free(attn_inter);
-                free(core_attn_out_chunk);
-                free(g_last);
-                free(g_diff_exp);
-                free(mask);
-                free(q_g_exp);
-                free(attn_v_new);
-                free(new_recurrent_state);
-                
-                // Free the value and k_cumdecay allocated at the beginning of the loop
-                free(value_chunk);
-                free(k_cumdecay);
             }
         }
-    }    
-}
+        print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
 
-static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
-    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n", 
-        name, token, data[0], data[1], data[2], data[3], data[4]);
-    double sum = 0.0;
-    for (unsigned int i = 0; i < size; i++) {
-        sum += data[i];
+        free(temp_state);
+        free(core_attn_out);
+        free(attn_inter);
+        free(v_new);
+        free(v_prime);
+        free(g_diff_exp);
+        free(g_last);
     }
-    GGML_LOG_INFO("sum = %.10f\n", sum);
+
+    GGML_ASSERT(output + S_v * H_v * n_tokens * n_seqs == new_state);
+    free(attn);
+    free(value);
+    free(k_cumdecay);
+    free(mask);
+    free(g);
 }
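For orientation, the chunked kernel above mirrors the Python reference quoted in its comments. Before the chunk loop it applies the triangular updates plus the identity to the precomputed attention block (src8) and derives value = attn @ v_beta and k_cumdecay = attn @ (k_beta * g.exp()); then, per chunk and per (seq, head) block, in the reference's notation:

    attn          = (q_i @ k_i.transpose(-1, -2) * decay_mask_i).masked_fill_(mask, 0)
    v_prime       = k_cumdecay @ last_recurrent_state
    v_new         = value - v_prime
    attn_inter    = (q_i * g_i.exp()) @ last_recurrent_state
    core_attn_out = attn_inter + attn @ v_new
    last_recurrent_state = last_recurrent_state * g_i[-1].exp()
                           + (k_i * (g_i[-1] - g_i).exp()).transpose(-1, -2) @ v_new

Each intermediate is dumped with print_debug_info, so the chunked path can be compared value-for-value against the recurrent implementation that follows, whose debug prints this commit re-enables as well.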
 
 void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * params, ggml_tensor * dst) {
@@ -10971,7 +11316,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
 
         // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10985,7 +11330,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
         
         // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11000,7 +11345,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        //print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
+        print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
         
         // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11012,7 +11357,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        //print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
+        print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
         
         // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11026,7 +11371,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        //print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
         
         // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11040,7 +11385,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        //print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
+        print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
         
         // Store the output for this token (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {

+ 4 - 1
src/models/llm_build_qwen3next.cpp

@@ -735,8 +735,11 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
+
+    ggml_tensor * attn_out_reshaped = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs);
+cb(attn_out_reshaped, "attn_out_reshaped", il);
     
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, attn_out_reshaped, 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
    
     // Extract the state part (second part of the concatenated tensor)
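The split in this hunk is structural: naming the intermediate lets it be tagged with cb() for graph debugging. Assuming the shapes implied by the ggml_cont_4d call, and noting that ggml_permute(..., 0, 2, 1, 3) swaps dimensions 1 and 2, the three tensors relate as follows:

    // attn_out_1d       : head_v_dim * n_seq_tokens * num_v_heads * n_seqs elements, flat
    // attn_out_reshaped : [head_v_dim, n_seq_tokens, num_v_heads, n_seqs]
    // attn_out_final    : [head_v_dim, num_v_heads, n_seq_tokens, n_seqs]  (after permute + cont)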