
Tensor preparation for delta_net complete

Piotr Wilkin, 3 months ago
Parent
Commit
c87e8d550c
6 changed files with 578 additions and 538 deletions
  1. ggml/include/ggml.h (+ 0 - 12)
  2. ggml/src/ggml-cpu/ops.cpp (+ 387 - 378)
  3. ggml/src/ggml.c (+ 0 - 118)
  4. src/llama-context.cpp (+ 1 - 1)
  5. src/models/llm_build_qwen3next.cpp (+ 178 - 18)
  6. src/models/llm_build_qwen3next.h (+ 12 - 11)

+ 0 - 12
ggml/include/ggml.h

@@ -2316,18 +2316,6 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * state);
 
-    GGML_API struct ggml_tensor * ggml_delta_net(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * q,
-            struct ggml_tensor  * k,
-            struct ggml_tensor  * v,
-            struct ggml_tensor  * g,
-            struct ggml_tensor  * beta,
-            struct ggml_tensor  * state,
-            bool                  use_qk_l2norm,
-            float                 scale,
-            float                 eps_norm);
-
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

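For orientation, here is a minimal sketch (not part of the commit) of the tensor shapes that the removed ggml_delta_net() entry point asserted on, reconstructed from the deleted ggml.c implementation further down. The concrete sizes and the scale/eps values are illustrative assumptions only.

    // Hypothetical shapes for the old public entry point; ctx is an initialized ggml_context.
    const int64_t S_k = 128, S_v = 128;      // key/value head dims (the op assumes S_k == S_v)
    const int64_t H_k = 4,   H_v = 4;        // head counts (asserted equal after the repeat)
    const int64_t n_tokens = 17, n_seqs = 1;

    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_k, H_k, n_tokens, n_seqs);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_k, H_k, n_tokens, n_seqs);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, H_v, n_tokens, n_seqs);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1,   H_v, n_tokens, n_seqs);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1,   H_v, n_tokens, n_seqs);
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v, S_v * H_v);

    // Before this commit these fed the removed API; the preparation now happens in the model builder:
    // ggml_delta_net(ctx, q, k, v, g, beta, state, /*use_qk_l2norm=*/true, 1.0f/sqrtf(S_k), 1e-6f);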
+ 387 - 378
ggml/src/ggml-cpu/ops.cpp

@@ -10590,6 +10590,51 @@ static void ggml_delta_net_matmul_f32(const float * a, const int64_t rows_a, con
     }
 }
 
+// Helper function to apply triangular updates to attention matrix
+static void delta_net_compute_diagonal_updates(
+    float * attn,
+    const int64_t chunk_size,
+    const int64_t n_heads,
+    const int64_t n_seqs) {
+    
+    // Apply triangular updates like in the Python reference:
+    // for i in range(1, chunk_size):
+    //     row = attn[..., i, :i].clone()
+    //     sub = attn[..., :i, :i].clone()
+    //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    
+    for (int64_t head = 0; head < n_heads; head++) {
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            // Get pointer to this head's attention matrix
+            float * attn_head = attn + (head * chunk_size * chunk_size * n_seqs) + (seq * chunk_size * chunk_size);
+            
+            // Create temporary storage for the original values to avoid in-place modification
+            float * attn_copy = (float *) malloc(chunk_size * chunk_size * sizeof(float));
+            memcpy(attn_copy, attn_head, chunk_size * chunk_size * sizeof(float));
+            
+            // Apply triangular updates using the original values
+            for (int64_t i = 1; i < chunk_size; i++) {
+                for (int64_t j = 0; j < i; j++) {
+                    float sum = 0.0f;
+                    // Compute (row.unsqueeze(-1) * sub).sum(-2) using original values
+                    for (int64_t k = 0; k < i; k++) {
+                        sum += attn_copy[i * chunk_size + k] * attn_copy[k * chunk_size + j];
+                    }
+                    attn_head[i * chunk_size + j] = attn_copy[i * chunk_size + j] + sum;
+                }
+            }
+            
+            // Add identity matrix
+            for (int64_t i = 0; i < chunk_size; i++) {
+                attn_head[i * chunk_size + i] += 1.0f;
+            }
+            
+            free(attn_copy);
+        }
+    }
+}
+
 void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // q (already normalized and scaled)
     const struct ggml_tensor * src1 = dst->src[1];  // k (already normalized)
@@ -10614,11 +10659,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     GGML_ASSERT(src1->ne[3] == n_seqs);  // k tensor
     GGML_ASSERT(src2->ne[3] == n_seqs);  // v tensor
     GGML_ASSERT(src3->ne[3] == n_seqs);  // g tensor
-    GGML_ASSERT(src4->ne[3] == n_seqs);  // beta tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs);  // state tensor
 
     float * dst_data  = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
-    float * output    = dst_data;  // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
+    float * output    = dst_data; // [S_v * H_v, n_tokens, 1, 1] - only real sequence length, not padded
     float * new_state = dst_data + (S_v * H_v * n_tokens);  // [S_v * H_v, S_v * n_seqs, 1, 1]
 
     const int ith = params->ith;
@@ -10633,414 +10678,378 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // Clear output and new state section
     memset(output, 0, ((S_v * H_v * n_tokens) + (S_v * H_v * S_v * n_seqs)) * sizeof(float));
 
-    // Get tensor data pointers
-    float * state_data = (float *) src4->data;
-    float * decay_mask = (float *) src5->data;
-
-    // Allocate temporary buffers for computation
+    // Calculate chunk size
     const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
-    // The first dimension is the chunk_size, second is head_dim, third is num_heads, fourth is n_seqs
-    // Note: In reference Python implementation, tensors are padded to multiple of chunk_size
-    // but the output only contains the real sequence length, not the padded length
-
-    // Calculate the actual padded sequence length for internal processing
-    const int64_t pad_size              = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t total_sequence_length = n_tokens + pad_size;
-    const int64_t n_chunks              = (total_sequence_length + chunk_size - 1) / chunk_size;  // Ceiling division
-
-    // Temporary buffers for each chunk
-    std::vector<float>  attn(chunk_size * chunk_size, 0.0f);
-    std::vector<float>  value(chunk_size * S_v, 0.0f);
-    std::vector<float>  k_cumdecay(chunk_size * S_k, 0.0f);
-    std::vector<double> g_exp(chunk_size, 0.0f);
-    std::vector<float>  g_cumsum(chunk_size, 0.0f);
-    std::vector<float>  last_state(S_v * S_v * H_v, 0.0f);
-
-    // Initialize last_state with input state data
-    // State format in GGML: [S_v, S_v * H_v, 1, 1] where S_v * H_v = S_v * num_heads
-    // The state tensor has format [S_v, S_v * H_v, 1, 1] where second dimension is S_v * num_heads
-    // For delta_net, S_k == S_v (both k and v have the same head dimension)
-    for (int64_t h = 0; h < H_v; h++) {
-        for (int64_t d1 = 0; d1 < S_v; d1++) {
-            for (int64_t d2 = 0; d2 < S_v; d2++) {
-                // GGML state index: [d1, d2 + h*S_v, 0, 0] in flattened form
-                int64_t ggml_state_idx         = d1 * (S_v * H_v) + (d2 + h * S_v);
-                // Our computed state index: [d1, d2 + h*S_v]
-                int64_t computed_state_idx     = d1 * (S_v * H_v) + (d2 + h * S_v);
-                last_state[computed_state_idx] = state_data[ggml_state_idx];
-            }
-        }
-    }
-
-    // Maintain running cumulative sum across all chunks
-    std::vector<float> running_cumsum(n_tokens, 0.0f);
-
-    // Process each chunk
-    for (int64_t chunk_idx = 0; chunk_idx < n_chunks; chunk_idx++) {
-        // Process each head and sequence
-        for (int64_t h = 0; h < H_k; h++) {
-            for (int64_t seq = 0; seq < n_seqs; seq++) {
-                // Extract chunk data for this head and sequence
-                std::vector<float> q_chunk(chunk_size * S_k);
-                std::vector<float> k_chunk(chunk_size * S_k);
-                std::vector<float> v_chunk(chunk_size * S_v);
-                std::vector<float> v_beta_chunk(chunk_size * S_v);
-                std::vector<float> k_beta_chunk(chunk_size * S_k);
-                std::vector<float> g_chunk(chunk_size);
-
-                // Initialize chunks with zeros for padding
-                std::fill(q_chunk.begin(), q_chunk.end(), 0.0f);
-                std::fill(k_chunk.begin(), k_chunk.end(), 0.0f);
-                std::fill(v_chunk.begin(), v_chunk.end(), 0.0f);
-                std::fill(v_beta_chunk.begin(), v_beta_chunk.end(), 0.0f);
-                std::fill(k_beta_chunk.begin(), k_beta_chunk.end(), 0.0f);
-                std::fill(g_chunk.begin(), g_chunk.end(), 0.0f);
-
-                // Determine actual tokens in this chunk
-                int64_t tokens_in_chunk = std::min(chunk_size, n_tokens - chunk_idx * chunk_size);
-
-                // Copy data for this chunk
-                for (int64_t t = 0; t < tokens_in_chunk; t++) {
-                    int64_t actual_pos = chunk_idx * chunk_size + t;  // Position in the original sequence
-
-                    // Only copy if this position is within the original sequence length
-                    if (actual_pos < n_tokens) {
-                        // Calculate indices in GGML format [chunk_size, head_dim, num_heads, n_seqs]
-                        for (int64_t d = 0; d < S_k; d++) {
-                            q_chunk[t * S_k + d]      = ggml_get_f32_nd(src0, actual_pos, d, h, seq);
-                            k_chunk[t * S_k + d]      = ggml_get_f32_nd(src1, actual_pos, d, h, seq);
-                            k_beta_chunk[t * S_k + d] = ggml_get_f32_nd(src7, actual_pos, d, h, seq);
-                        }
-
-                        for (int64_t d = 0; d < S_v; d++) {
-                            v_chunk[t * S_v + d]      = ggml_get_f32_nd(src2, actual_pos, d, h, seq);
-                            v_beta_chunk[t * S_v + d] = ggml_get_f32_nd(src6, actual_pos, d, h, seq);
-                        }
-
-                        if (actual_pos <
-                            n_tokens) {  // Only copy if this position is within the original sequence length
-                            // Use the safe GGML function to access tensor values
-                            g_chunk[t] = ggml_get_f32_nd(src3, actual_pos, 0, h, seq);
-                        } else {
-                            // For padded positions, set to 0 (or a default value)
-                            g_chunk[t] = 0.0f;
-                        }
-                    } else {
-                        // For padded positions beyond original sequence, set to 0
-                        for (int64_t d = 0; d < S_k; d++) {
-                            q_chunk[t * S_k + d]      = 0.0f;
-                            k_chunk[t * S_k + d]      = 0.0f;
-                            k_beta_chunk[t * S_k + d] = 0.0f;
-                        }
-                        for (int64_t d = 0; d < S_v; d++) {
-                            v_chunk[t * S_v + d]      = 0.0f;
-                            v_beta_chunk[t * S_v + d] = 0.0f;
-                        }
-                        g_chunk[t] = 0.0f;
-                    }
-                }
-
-                // In Python, cumsum is applied to each chunk separately after reshaping
-                // So we need to compute cumsum within this chunk only
-
-                // g_chunk already contains the cumsum values from src3 (g_cumsum), so use them directly
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only for actual tokens, not full chunk size
-                    g_cumsum[i] = g_chunk[i];
-                }
-
-                // For padded positions, set cumsum values to 0
-                for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
-                    g_cumsum[i] = 0.0f;
-                }
-
-                // Compute g_exp from cumulative sums (like Python: g.cumsum().exp())
-                // Apply numerical stability to prevent underflow for very negative values
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only for actual tokens, not full chunk size
-                    // Use double precision for exponential to avoid overflow/underflow
-                    // Apply lower bound to prevent extreme underflow - exp(-50) is about 1.9e-22
-                    double g_val        = (double) g_cumsum[i];
-                    double g_exp_double = exp(g_val);
-                    g_exp[i]            = g_exp_double;
-                }
-
-                // For padded positions, set exp values to 0
-                for (int64_t i = tokens_in_chunk; i < chunk_size; i++) {
-                    g_exp[i] = 0.0f;
-                }
-                // Step 1: Compute k_beta @ key.T (this corresponds to the Python: k_beta @ key.transpose(-1, -2))
-                // Only compute for actual tokens in chunk
-                ggml_compute_k_beta_key_t_f32(k_beta_chunk.data(), k_chunk.data(), attn.data(), tokens_in_chunk,
-                                              S_k);  // Use actual tokens, not full chunk size
-
-                // Apply precomputed decay mask from src5 and negate the result (like Python: -(...))
-                // The decay mask is computed in ggml_delta_net in ggml.c and passed as src5
-                // Apply the precomputed decay mask from src5 (decay_mask tensor)
-                // The decay_mask tensor now contains exp(g_cumsum[j] - g_cumsum[i]) values
-                // where g_cumsum[j] - g_cumsum[i] is computed in the main function
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {      // Only for actual tokens
-                    for (int64_t j = 0; j < tokens_in_chunk; j++) {  // Only for actual tokens
-                        // Get decay mask value from precomputed tensor
-                        // src5 decay_mask has shape [chunk_size, chunk_size, H_k, n_seqs] in GGML format
-                        // Format: [i_pos, j_pos, head, seq] - represents exp(g_cumsum[j] - g_cumsum[i])
-                        float decay_val = ggml_get_f32_nd(
-                            src5, i, j, h, seq);  // [i, j, h, seq] to get exp(g_cumsum[j] - g_cumsum[i]) for head h
-                        if (j <= i) {             // Only apply to lower triangular part (i >= j)
-                            // The decay_val already contains exp(g_cumsum[j] - g_cumsum[i]), no need for additional exponential
-                            // Apply the decay mask and negate (like Python: -((k_beta @ key.T) * decay_mask))
-                            attn[i * chunk_size + j] = -attn[i * chunk_size + j] * decay_val;
-                        } else {
-                            attn[i * chunk_size + j] =
-                                0.0f;  // Zero out upper triangular part (like Python: masked_fill(mask, 0))
-                        }
+    const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
+
+    // Apply triangular updates to the precomputed attention matrix
+    // This is the missing piece that was causing the attention matrix to have all zeros
+    float * attn_data = (float *) src8->data;
+    delta_net_compute_diagonal_updates(attn_data, chunk_size, H_v, n_seqs);
+    
+    // Debug: Check attention matrix after triangular updates
+    float attn_after_updates_sum = 0.0f;
+    float attn_after_updates_max = 0.0f;
+    for (int64_t i = 0; i < chunk_size * chunk_size * H_v * n_seqs; i++) {
+        attn_after_updates_sum += attn_data[i];
+        attn_after_updates_max = fmaxf(attn_after_updates_max, fabsf(attn_data[i]));
+    }
+    printf("C++ attn_after_triangular_updates sum = %f, max = %f\n", attn_after_updates_sum, attn_after_updates_max);
+
+    // Compute value = attn @ v_beta and k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    // These should be computed once before the chunk loop, like in the Python reference
+    printf("=== Computing value and k_cumdecay before chunk loop ===\n");
+    
+    // Compute value and k_cumdecay for each head and sequence
+    for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
+        for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
+            // Get offsets for this head and sequence
+            const int64_t attn_offset = (seq_idx * src8->nb[3] + head_idx * src8->nb[2]) / sizeof(float);
+            const int64_t v_beta_offset = (seq_idx * src6->nb[3] + head_idx * src6->nb[2]) / sizeof(float);
+            const int64_t k_beta_offset = (seq_idx * src7->nb[3] + head_idx * src7->nb[2]) / sizeof(float);
+            const int64_t g_offset = (seq_idx * src3->nb[3] + head_idx * src3->nb[2]) / sizeof(float);
+            
+            float * attn_precomputed = (float *) src8->data + attn_offset;
+            float * v_beta_ptr = (float *) src6->data + v_beta_offset;
+            float * k_beta_ptr = (float *) src7->data + k_beta_offset;
+            float * g_vals = (float *) src3->data + g_offset;
+            
+            // Compute value = attn @ v_beta
+            float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
+            for (int64_t i = 0; i < chunk_size; i++) {
+                for (int64_t d = 0; d < S_v; d++) {
+                    float sum = 0.0f;
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        sum += attn_precomputed[i * chunk_size + j] * v_beta_ptr[j * S_v + d];
                     }
+                    value[i * S_v + d] = sum;
                 }
-
-                // Step 2: Apply triangular updates (equivalent to Python's complex triangular update)
-                // Python: for i in range(1, chunk_size):
-                //           row = attn[..., i, :i].clone()  // row = attn[i, 0:i]
-                //           sub = attn[..., :i, :i].clone() // sub = attn[0:i, 0:i]
-                //           attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-                // This means: new_attn[i, j] = old_attn[i, j] + sum_k(old_attn[i, k] * old_attn[k, j]) for k < i
-                for (int64_t i = 1; i < tokens_in_chunk; i++) {
-                    // Store the original row values to avoid using updated values in computation
-                    std::vector<float> original_row(i);
-                    for (int64_t j = 0; j < i; j++) {
-                        original_row[j] = attn[i * tokens_in_chunk + j];  // Use tokens_in_chunk for indexing
+            }
+            
+            float value_sum = 0.0f;
+            for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                value_sum += value[i];
+            }
+            printf("C++ PRE-CHUNK value_sum = %f (head %ld, seq %ld)\n", value_sum, head_idx, seq_idx);
+            
+            // Compute k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+            float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
+            for (int64_t i = 0; i < chunk_size; i++) {
+                for (int64_t d = 0; d < S_k; d++) {
+                    float sum = 0.0f;
+                    for (int64_t j = 0; j < chunk_size; j++) {
+                        float g_exp = expf(g_vals[j]);
+                        sum += attn_precomputed[i * chunk_size + j] * k_beta_ptr[j * S_k + d] * g_exp;
                     }
-
-                    for (int64_t j = 0; j < i; j++) {
+                    k_cumdecay[i * S_k + d] = sum;
+                }
+            }
+            
+            float k_cumdecay_sum = 0.0f;
+            for (int64_t i = 0; i < chunk_size * S_k; i++) {
+                k_cumdecay_sum += k_cumdecay[i];
+            }
+            printf("C++ PRE-CHUNK k_cumdecay_sum = %f (head %ld, seq %ld)\n", k_cumdecay_sum, head_idx, seq_idx);
+            
+            free(value);
+            free(k_cumdecay);
+        }
+    }
+    printf("=== End pre-chunk computations ===\n");
+
+    // Initialize last_recurrent_state
+    // last_recurrent_state = torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
+    // if initial_state is None else initial_state.to(value)
+    float * initial_state_ptr = (float *) src4->data;
+    
+    // If initial_state is provided, copy it to new_state, otherwise initialize new_state to zeros
+    // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
+    // This means: [n_heads * v_head_dim, v_head_dim * n_seqs, 1, 1]
+    // So total size is: S_v * H_v * S_v * n_seqs
+    if (initial_state_ptr != NULL) {
+        // Copy initial state to new state
+        memcpy(new_state, initial_state_ptr, S_v * H_v * S_v * n_seqs * sizeof(float));
+    } else {
+        // Initialize new state to zeros
+        memset(new_state, 0, S_v * H_v * S_v * n_seqs * sizeof(float));
+    }
+    
+    // Process each chunk for the main computation
+    // Following the Python reference implementation exactly
+    for (int64_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
+        for (int64_t seq_idx = 0; seq_idx < n_seqs; seq_idx++) {
+            // Process each head in this chunk
+            for (int64_t head_idx = 0; head_idx < H_v; head_idx++) {
+                // Get the recurrent state for this sequence and head
+                // Output state layout: [S_v * H_v, S_v * n_seqs, 1, 1]
+                // For each head and sequence, we need to find the correct state slice
+                // The state is organized as: [head_idx * S_v * S_v * n_seqs + seq_idx * S_v * S_v]
+                float * last_recurrent_state = new_state + (head_idx * S_v * S_v * n_seqs) + (seq_idx * S_v * S_v);
+                printf("\n=== C++ Processing chunk %ld, seq %ld, head %ld ===\n", chunk_idx, seq_idx, head_idx);
+                // Get pointers to current chunk data for this head
+                // GGML tensor layout: [S_k/S_v, chunk_size, H_v, n_seqs]
+                // Python layout: [batch_size, sequence_length, num_heads, k_head_dim]
+                // After transpose: [batch_size, num_heads, sequence_length, k_head_dim]
+                
+                // For GGML: ne[0]=S_k/S_v, ne[1]=chunk_size, ne[2]=H_v, ne[3]=n_seqs
+                // nb[0]=sizeof(float)*S_k/S_v, nb[1]=sizeof(float)*S_k/S_v*chunk_size, etc.
+                
+                const int64_t q_offset = (seq_idx * src0->nb[3] + head_idx * src0->nb[2]) / sizeof(float);
+                const int64_t k_offset = (seq_idx * src1->nb[3] + head_idx * src1->nb[2]) / sizeof(float);
+                const int64_t v_offset = (seq_idx * src2->nb[3] + head_idx * src2->nb[2]) / sizeof(float);
+                const int64_t g_offset = (seq_idx * src3->nb[3] + head_idx * src3->nb[2]) / sizeof(float);
+                const int64_t v_beta_offset = (seq_idx * src6->nb[3] + head_idx * src6->nb[2]) / sizeof(float);
+                const int64_t k_beta_offset = (seq_idx * src7->nb[3] + head_idx * src7->nb[2]) / sizeof(float);
+                const int64_t attn_offset = (seq_idx * src8->nb[3] + head_idx * src8->nb[2]) / sizeof(float);
+                const int64_t decay_mask_offset = (seq_idx * src5->nb[3] + head_idx * src5->nb[2]) / sizeof(float);
+                
+                // Calculate strides for each tensor
+                const int64_t q_stride0 = src0->nb[0] / sizeof(float);  // S_k
+                const int64_t q_stride1 = src0->nb[1] / sizeof(float);  // chunk_size
+                const int64_t k_stride0 = src1->nb[0] / sizeof(float);  // S_k
+                const int64_t k_stride1 = src1->nb[1] / sizeof(float);  // chunk_size
+                const int64_t v_stride0 = src2->nb[0] / sizeof(float);  // S_v
+                const int64_t v_stride1 = src2->nb[1] / sizeof(float);  // chunk_size
+                const int64_t g_stride0 = src3->nb[0] / sizeof(float);  // chunk_size
+                const int64_t v_beta_stride0 = src6->nb[0] / sizeof(float);  // S_v
+                const int64_t v_beta_stride1 = src6->nb[1] / sizeof(float);  // chunk_size
+                const int64_t k_beta_stride0 = src7->nb[0] / sizeof(float);  // S_k
+                const int64_t k_beta_stride1 = src7->nb[1] / sizeof(float);  // chunk_size
+                const int64_t attn_stride0 = src8->nb[0] / sizeof(float);  // chunk_size
+                const int64_t attn_stride1 = src8->nb[1] / sizeof(float);  // chunk_size
+                const int64_t decay_mask_stride0 = src5->nb[0] / sizeof(float);  // chunk_size
+                const int64_t decay_mask_stride1 = src5->nb[1] / sizeof(float);  // chunk_size
+                
+                // Get decay mask for this chunk and head
+                float * decay_mask = (float *) src5->data + decay_mask_offset;
+                
+                // Use pre-computed attention matrix from src8 (after triangular updates)
+                // The Python reference computes triangular updates before the chunk loop
+                float * attn_precomputed = (float *) src8->data + attn_offset;
+                
+                // Debug: print precomputed attention matrix values
+                float attn_precomputed_sum = 0.0f;
+                float attn_precomputed_max = 0.0f;
+                for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
+                    attn_precomputed_sum += attn_precomputed[i];
+                    attn_precomputed_max = fmaxf(attn_precomputed_max, fabsf(attn_precomputed[i]));
+                }
+                printf("C++ attn_precomputed_sum = %f, max = %f\n", attn_precomputed_sum, attn_precomputed_max);
+                printf("C++ attn_precomputed first 10 values: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n",
+                       attn_precomputed[0], attn_precomputed[1], attn_precomputed[2], attn_precomputed[3], attn_precomputed[4],
+                       attn_precomputed[5], attn_precomputed[6], attn_precomputed[7], attn_precomputed[8], attn_precomputed[9]);
+                printf("C++ attn_precomputed diagonal values: %f, %f, %f, %f, %f\n",
+                       attn_precomputed[0], attn_precomputed[65], attn_precomputed[130], attn_precomputed[195], attn_precomputed[260]);
+                
+                // Get g values for this chunk and head
+                float * g_vals = (float *) src3->data + g_offset;
+                
+                // Get v_beta and k_beta for this chunk and head
+                float * v_beta_ptr = (float *) src6->data + v_beta_offset;
+                float * k_beta_ptr = (float *) src7->data + k_beta_offset;
+                
+                // Debug: print v_beta and k_beta values
+                float v_beta_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    v_beta_sum += v_beta_ptr[i];
+                }
+                printf("C++ v_beta_sum = %f\n", v_beta_sum);
+                
+                float k_beta_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_k; i++) {
+                    k_beta_sum += k_beta_ptr[i];
+                }
+                printf("C++ k_beta_sum = %f\n", k_beta_sum);
+                
+                // Compute value = attn_precomputed @ v_beta
+                float * value = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
                         float sum = 0.0f;
-                        for (int64_t k = 0; k < i; k++) {
-                            // This implements: sum over k of (original_row[k] * sub[k, j])
-                            // Where sub[k, j] is attn[k, j] (the original value before updates)
-                            sum += original_row[k] * attn[k * tokens_in_chunk + j];  // Use tokens_in_chunk for indexing
+                        for (int64_t j = 0; j < chunk_size; j++) {
+                            sum += attn_precomputed[i * chunk_size + j] * v_beta_ptr[j * v_beta_stride0 + d * v_beta_stride1];
                         }
-                        // The new value is: original_value + matrix_mult_result
-                        attn[i * tokens_in_chunk + j] = original_row[j] + sum;
+                        value[i * S_v + d] = sum;
                     }
                 }
-
-                // Step 3: Add identity matrix (equivalent to Python's: attn = attn + torch.eye(...))
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {
-                    attn[i * tokens_in_chunk + i] += 1.0f;
-                }
-
-                // Step 4: Compute value = attn @ v_beta
-                ggml_compute_value_f32(attn.data(), v_beta_chunk.data(), value.data(), tokens_in_chunk,
-                                       S_v);  // Use actual tokens, not full chunk size
-
-                // Step 5: Compute k_cumdecay = attn @ (k_beta * g_exp)
-                ggml_compute_k_cumdecay_f32(attn.data(), k_beta_chunk.data(), g_exp.data(), k_cumdecay.data(),
-                                            tokens_in_chunk, S_k);  // Use actual tokens, not full chunk size
-
-                // Step 6: Compute core attention output for this chunk
-                // First, compute v_new for all tokens in the chunk
-                std::vector<float> v_new_chunk(tokens_in_chunk * S_v);
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {
-                    // v_prime = k_cumdecay @ last_state
-                    // k_cumdecay[i] is [S_k], last_state for head h is [S_k, S_v]
-                    std::vector<float> v_prime(S_v, 0.0f);
-                    for (int64_t d1 = 0; d1 < S_v; d1++) {
-                        for (int64_t d2 = 0; d2 < S_k; d2++) {
-                            // State index: [d2, d1 + h*S_v] in GGML format
-                            int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
-                            v_prime[d1] += k_cumdecay[i * S_k + d2] * last_state[state_idx];
+                
+                float value_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    value_sum += value[i];
+                }
+                printf("C++ value_sum = %f (head %ld, seq %ld)\n", value_sum, head_idx, seq_idx);
+                
+                // Compute k_cumdecay = attn_precomputed @ (k_beta * g.exp().unsqueeze(-1))
+                float * k_cumdecay = (float *) malloc(chunk_size * S_k * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_k; d++) {
+                        float sum = 0.0f;
+                        for (int64_t j = 0; j < chunk_size; j++) {
+                            float g_exp = expf(g_vals[j * g_stride0]);
+                            sum += attn_precomputed[i * chunk_size + j] * k_beta_ptr[j * k_beta_stride0 + d * k_beta_stride1] * g_exp;
                         }
-                    }
-
-                    // v_new = v_i - v_prime
-                    for (int64_t d = 0; d < S_v; d++) {
-                        v_new_chunk[i * S_v + d] = value[i * S_v + d] - v_prime[d];
+                        k_cumdecay[i * S_k + d] = sum;
                     }
                 }
-
-                // Now process each token in the chunk to compute output
-                for (int64_t i = 0; i < tokens_in_chunk; i++) {
-                    // q_i @ k_i.T * decay_mask
-                    std::vector<float> q_k_attn(chunk_size);
+                
+                float k_cumdecay_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_k; i++) {
+                    k_cumdecay_sum += k_cumdecay[i];
+                }
+                printf("C++ k_cumdecay_sum = %f (head %ld, seq %ld)\n", k_cumdecay_sum, head_idx, seq_idx);
+                
+                // Compute fresh attention matrix for this chunk, just like Python reference line 118
+                // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+                float * attn = (float *) malloc(chunk_size * chunk_size * sizeof(float));
+                
+                // First compute q_i @ k_i.transpose(-1, -2)
+                for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < chunk_size; j++) {
                         float sum = 0.0f;
                         for (int64_t d = 0; d < S_k; d++) {
-                            sum += q_chunk[i * S_k + d] * k_chunk[j * S_k + d];
-                        }
-                        // Apply decay mask - use the precomputed decay mask from src5 tensor
-                        if (j <= i) {                 // Only apply to lower triangular part (i >= j)
-                            float decay_val = ggml_get_f32_nd(
-                                src5, i, j, h, seq);  // [i, j, h, seq] to get exp(g_cumsum[i] - g_cumsum[j]) for head h
-                            q_k_attn[j] = sum * decay_val;
-                        } else {
-                            q_k_attn[j] = 0.0f;  // Zero out upper triangular part (i < j)
+                            float q_val = ((float *)src0->data)[q_offset + d * q_stride0 + i * q_stride1];
+                            float k_val = ((float *)src1->data)[k_offset + d * k_stride0 + j * k_stride1];
+                            sum += q_val * k_val;
                         }
+                        attn[i * chunk_size + j] = sum * decay_mask[i * chunk_size + j];
                     }
-
-                    // attn_inter = q_i * g_exp @ last_state
-                    // q_chunk[i] is [S_k], g_exp[i] is scalar, last_state for head h is [S_k, S_v]
-                    std::vector<float> attn_inter(S_v, 0.0f);
-                    for (int64_t d1 = 0; d1 < S_v; d1++) {
-                        for (int64_t d2 = 0; d2 < S_k; d2++) {
-                            // State index: [d2, d1 + h*S_v] in GGML format
-                            int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
-                            // Use double precision for the computation and then cast to float
-                            double  temp_result =
-                                (double) q_chunk[i * S_k + d2] * g_exp[i] * (double) last_state[state_idx];
-                            attn_inter[d1] += (float) temp_result;
-                        }
+                }
+                
+                // Apply upper triangular mask (masked_fill_(mask, 0))
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t j = i + 1; j < chunk_size; j++) {
+                        attn[i * chunk_size + j] = 0.0f;
                     }
-
-                    // core_attn_out = attn_inter + attn @ v_new
-                    // We need to use the attention matrix computed for this position (i)
-                    // The attn matrix was computed earlier in the chunk processing
-                    // attn @ v_new where attn is [chunk_size, chunk_size] and v_new is [chunk_size, S_v]
-                    // For token i, we want sum_j(attn[i, j] * v_new[j, :])
-                    std::vector<float> attn_v_new(S_v, 0.0f);
+                }
+                
+                // Compute v_prime = k_cumdecay @ last_recurrent_state
+                float * v_prime = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t d = 0; d < S_v; d++) {
-                        for (int64_t j = 0; j < tokens_in_chunk; j++) {  // Only process actual tokens
-                            // Use the attention matrix that was computed for position i
-                            // attn[i * chunk_size + j] is the attention from position i to j
-                            // v_new_chunk[j * S_v + d] is the v_new value for token j, dimension d
-                            attn_v_new[d] += attn[i * chunk_size + j] * v_new_chunk[j * S_v + d];
-                        }
-                    }
-
-                    // Store output - only store for the original sequence length (not the padded part)
-                    int64_t global_pos =
-                        chunk_idx * chunk_size + i;  // Convert local chunk position to global sequence position
-                    if (global_pos < n_tokens) {     // Make sure we don't exceed the original sequence length
-                        for (int64_t d = 0; d < S_v; d++) {
-                            // Output tensor is [S_v * H_v * n_tokens] for single sequence (n_seqs=1)
-                            // Indexing: [dim_idx + head_idx*S_v + pos_idx*S_v*H_v]
-                            int64_t ggml_idx = d + h * S_v + global_pos * S_v * H_v;
-                            output[ggml_idx] = attn_inter[d] + attn_v_new[d];
+                        float sum = 0.0f;
+                        for (int64_t k = 0; k < S_k; k++) {
+                            sum += k_cumdecay[i * S_k + k] * last_recurrent_state[k * S_v + d];
                         }
+                        v_prime[i * S_v + d] = sum;
                     }
                 }
-
-                // Step 7: Update last_recurrent_state
-                std::vector<float> new_state_vec(S_v * S_v * H_v);
-
-                // Update running cumulative sum with current chunk's values
-                float prev_cumsum = 0.0f;  // Cumulative sum from all previous chunks
-                if (chunk_idx > 0) {
-                    // Get the cumulative sum of the last token from the previous chunk
-                    int64_t prev_chunk_last_token = std::min(chunk_size, n_tokens - (chunk_idx - 1) * chunk_size) - 1;
-                    if (prev_chunk_last_token >= 0) {
-                        prev_cumsum = running_cumsum[(chunk_idx - 1) * chunk_size + prev_chunk_last_token];
+                
+                // Debug prints for key intermediate values
+                float attn_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * chunk_size; i++) {
+                    attn_sum += attn[i];
+                }
+                printf("C++ attn_sum = %f\n", attn_sum);
+                
+                // Debug: print first few values of attn matrix
+                printf("C++ attn first 5 values: %f, %f, %f, %f, %f\n",
+                       attn[0], attn[1], attn[2], attn[3], attn[4]);
+                
+                float v_prime_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    v_prime_sum += v_prime[i];
+                }
+                printf("C++ v_prime_sum = %f\n", v_prime_sum);
+                
+                // Compute v_new = v_i - v_prime
+                float * v_new = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        v_new[i * S_v + d] = ((float *)src2->data)[v_offset + d * v_stride0 + i * v_stride1] - v_prime[i * S_v + d];
                     }
                 }
-
-                // Update running_cumsum for tokens in this chunk
-                for (int64_t t = 0; t < tokens_in_chunk; t++) {
-                    int64_t global_pos = chunk_idx * chunk_size + t;
-                    if (global_pos < n_tokens) {
-                        running_cumsum[global_pos] = prev_cumsum + g_cumsum[t];
-                    }
+                
+                float v_new_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    v_new_sum += v_new[i];
                 }
-
-                // Find the last token position in the current chunk (not the entire sequence)
-                int64_t last_pos_in_chunk =
-                    std::min((chunk_idx + 1) * chunk_size, n_tokens) - 1;  // Last actual token in this chunk
-                if (last_pos_in_chunk >= chunk_idx * chunk_size && last_pos_in_chunk < n_tokens) {
-                    float g_last =
-                        running_cumsum[last_pos_in_chunk];  // Use the last token's cumulative sum in this chunk
-                    // Use double precision for exponential to avoid overflow/underflow
-                    double g_last_exp_double = exp((double) g_last);
-                    float  g_last_exp        = (float) g_last_exp_double;
-
-                    // last_state * g_exp[last]
-                    for (int64_t i = 0; i < S_k; i++) {
-                        for (int64_t j = 0; j < S_v; j++) {
-                            // State index: [i, j + h*S_v] in GGML format
-                            int64_t state_idx                              = i * (S_v * H_v) + (j + h * S_v);
-                            new_state_vec[i * (S_v * H_v) + (j + h * S_v)] = last_state[state_idx] * g_last_exp;
+                printf("C++ v_new_sum = %f\n", v_new_sum);
+                
+                // Compute attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+                float * attn_inter = (float *) malloc(chunk_size * S_v * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        float sum = 0.0f;
+                        float g_exp = expf(g_vals[i * g_stride0]);
+                        for (int64_t k = 0; k < S_k; k++) {
+                            sum += ((float *)src0->data)[q_offset + k * q_stride0 + i * q_stride1] * g_exp * last_recurrent_state[k * S_v + d];
                         }
+                        attn_inter[i * S_v + d] = sum;
                     }
-
-                    // Add (k_i * (g_last - g_i).exp()).T @ v_new
-                    // This should be: (k_chunk * g_diff_exp).T @ v_new_chunk
-                    // where k_chunk is [chunk_size, S_k], v_new_chunk is [chunk_size, S_v]
-                    // result is [S_k, S_v]
-
-                    // First compute v_new for all positions in the chunk
-                    std::vector<float> v_new_chunk(chunk_size * S_v);
-                    for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only process actual tokens, not full chunk size
-                        for (int64_t d1 = 0; d1 < S_v; d1++) {
-                            // Recompute v_prime for this position
-                            float v_prime = 0.0f;
-                            for (int64_t d2 = 0; d2 < S_k; d2++) {
-                                // State index: [d2, d1 + h*S_v] in GGML format
-                                int64_t state_idx = d2 * (S_v * H_v) + (d1 + h * S_v);
-                                float   k_val     = k_cumdecay[i * S_k + d2];
-                                float   s_val     = last_state[state_idx];
-                                v_prime += k_val * s_val;
-                            }
-                            v_new_chunk[i * S_v + d1] = value[i * S_v + d1] - v_prime;
+                }
+                
+                float attn_inter_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    attn_inter_sum += attn_inter[i];
+                }
+                printf("C++ attn_inter_sum = %f\n", attn_inter_sum);
+                
+                // Compute core_attn_out = attn_inter + attn @ v_new
+                // Output tensor layout: [S_v * H_v, n_tokens, 1, 1]
+                const int64_t out_offset = head_idx * (S_v * n_tokens) + chunk_idx * (S_v * chunk_size);
+                float * core_attn_out = output + out_offset;
+                
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        float sum = 0.0f;
+                        for (int64_t j = 0; j < chunk_size; j++) {
+                            sum += attn[i * chunk_size + j] * v_new[j * S_v + d];
                         }
+                        core_attn_out[i * S_v + d] = attn_inter[i * S_v + d] + sum;
                     }
-
-                    // Now compute (k_chunk * g_diff_exp).T @ v_new_chunk
-                    // This is a matrix multiplication: [S_k, chunk_size] @ [chunk_size, S_v] = [S_k, S_v]
-                    // Only process the original sequence length, not the padded chunk size
-                    // In the Python reference, this is: (k_i * g_diff_exp).transpose(-1, -2) @ v_new
-                    // where g_diff_exp = torch.exp(g_last - g) and g_last = g[-1] (last token in chunk)
-                    for (int64_t d1 = 0; d1 < S_k; d1++) {
-                        for (int64_t d2 = 0; d2 < S_v; d2++) {
-                            float sum = 0.0f;
-                            for (int64_t i = 0; i < tokens_in_chunk; i++) {  // Only process actual tokens
-                                // Get g values for the current chunk from the cumsum tensor (src3)
-                                // For state update: g_last (last token in chunk) - g_current (current token)
-                                // g tensor has shape [GGML_DELTA_NET_CHUNK, 1, H_v, n_seqs] in GGML format after cumsum and reshaping
-
-                                // Access g_cumsum for current position in chunk - need to access the original g tensor before cumsum
-                                // The g_cumsum tensor is src3, but we need the original g values for the diff computation
-                                // Actually, we need to access g values that were cumsummed to compute the diff
-
-                                // Get the original g_cumsum values for current and last token in the chunk
-                                // g_cumsum values are stored in src3, which was reshaped from [chunk_size, 1, H_v, n_seqs] to [chunk_size, 1, H_v, n_seqs]
-                                float g_current = g_cumsum[i];      // Use the g_cumsum computed earlier in this chunk
-                                float g_last =
-                                    g_cumsum[tokens_in_chunk - 1];  // Use the last token's cumsum in this chunk
-
-                                float g_diff = g_last - g_current;
-                                float g_diff_exp;
-                                // Use double precision for exponential to avoid overflow/underflow
-                                // For numerical stability, if g_diff is very negative, exp(g_diff) will be very small
-                                if (g_diff < -50.0f) {
-                                    g_diff_exp = 0.0f;  // Set to zero to avoid underflow
-                                } else {
-                                    double g_diff_exp_double = exp((double) g_diff);
-                                    g_diff_exp               = (float) g_diff_exp_double;
-                                }
-                                sum += k_chunk[i * S_k + d1] * g_diff_exp * v_new_chunk[i * S_v + d2];
-                            }
-                            // State index: [d1, d2 + h*S_v] in GGML format
-                            int64_t state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                            new_state_vec[state_idx] += sum;
-                        }
+                }
+                
+                float core_attn_out_sum = 0.0f;
+                for (int64_t i = 0; i < chunk_size * S_v; i++) {
+                    core_attn_out_sum += core_attn_out[i];
+                }
+                printf("C++ core_attn_out_sum = %f\n", core_attn_out_sum);
+                
+                // Update last_recurrent_state
+                // last_recurrent_state = (
+                //     last_recurrent_state * g[:, :, i, -1, None, None].exp()
+                //     + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
+                // )
+                
+                float g_last = g_vals[chunk_size - 1];
+                float g_last_exp = expf(g_last);
+                
+                // First part: last_recurrent_state * g_last_exp
+                for (int64_t k = 0; k < S_k; k++) {
+                    for (int64_t v = 0; v < S_v; v++) {
+                        last_recurrent_state[k * S_v + v] *= g_last_exp;
                     }
-
-                    // Update last_state
-                    for (int64_t i = 0; i < S_k; i++) {
-                        for (int64_t j = 0; j < S_v; j++) {
-                            // State index: [i, j + h*S_v] in GGML format
-                            int64_t state_idx     = i * (S_v * H_v) + (j + h * S_v);
-                            last_state[state_idx] = new_state_vec[state_idx];
+                }
+                
+                // Second part: (k_i * (g_last - g).exp()).transpose(-1, -2) @ v_new
+                float * k_gated = (float *) malloc(chunk_size * S_k * sizeof(float));
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    float g_diff_exp = expf(g_last - g_vals[i]);
+                    for (int64_t d = 0; d < S_k; d++) {
+                        k_gated[i * S_k + d] = ((float *)src1->data)[k_offset + d * k_stride0 + i * k_stride1] * g_diff_exp;
+                    }
+                }
+                
+                // Compute k_gated.T @ v_new
+                for (int64_t k = 0; k < S_k; k++) {
+                    for (int64_t v = 0; v < S_v; v++) {
+                        float sum = 0.0f;
+                        for (int64_t i = 0; i < chunk_size; i++) {
+                            sum += k_gated[i * S_k + k] * v_new[i * S_v + v];
                         }
+                        last_recurrent_state[k * S_v + v] += sum;
                     }
                 }
-            }
-        }
-    }
-    // Copy the final state to the output tensor in the correct GGML layout
-    // GGML expects state layout: [d1, d2 + h*head_dim]
-    for (int64_t h = 0; h < H_v; h++) {
-        for (int64_t d1 = 0; d1 < S_v; d1++) {
-            for (int64_t d2 = 0; d2 < S_v; d2++) {
-                // GGML state index: [d1, d2 + h*head_dim]
-                int64_t ggml_state_idx     = d1 * (S_v * H_v) + (d2 + h * S_v);
-                // Our computed state index: [d1, d2 + h*S_v]
-                int64_t computed_state_idx = d1 * (S_v * H_v) + (d2 + h * S_v);
-                float   val                = last_state[computed_state_idx];
-                new_state[ggml_state_idx]  = val;
+                
+                // Free temporary memory
+                free(attn);
+                free(value);
+                free(k_cumdecay);
+                free(v_prime);
+                free(v_new);
+                free(attn_inter);
+                free(k_gated);
             }
         }
     }

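To make the chunk recurrence of the rewritten ggml_compute_forward_delta_net_f32 easier to follow, here is a condensed, self-contained sketch for a single head and a single chunk over plain row-major arrays. It follows the Python reference quoted in the comments above (value = attn @ v_beta, v_new = value - v_prime, and a fresh q @ k^T * decay matrix for the intra-chunk part), and drops the GGML stride handling, padding and printf instrumentation; the function name and flat layouts are ours, not part of the commit.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // One chunk of the gated delta rule for a single head, mirroring the loop body above.
    //   attn  : [C x C]  (k_beta @ k^T * decay) after triangular updates + identity
    //   q, k  : [C x Dk], k_beta: [C x Dk], v_beta: [C x Dv], g: [C] per-chunk cumsum of log-decay
    //   decay : [C x C]  exp(g[i] - g[j]) kept on the causal (lower) triangle, 0 elsewhere
    //   state : [Dk x Dv] running recurrent state, updated in place
    //   out   : [C x Dv]  core attention output for this chunk
    static void delta_net_chunk_step(
            const std::vector<float> & attn,   const std::vector<float> & decay,
            const std::vector<float> & q,      const std::vector<float> & k,
            const std::vector<float> & k_beta, const std::vector<float> & v_beta,
            const std::vector<float> & g,
            std::vector<float> & state, std::vector<float> & out,
            int64_t C, int64_t Dk, int64_t Dv) {
        std::vector<float> value(C * Dv), k_cumdecay(C * Dk), v_new(C * Dv);

        // value = attn @ v_beta ;  k_cumdecay = attn @ (k_beta * exp(g))
        for (int64_t i = 0; i < C; i++) {
            for (int64_t d = 0; d < Dv; d++) {
                float s = 0.0f;
                for (int64_t j = 0; j < C; j++) s += attn[i*C + j] * v_beta[j*Dv + d];
                value[i*Dv + d] = s;
            }
            for (int64_t d = 0; d < Dk; d++) {
                float s = 0.0f;
                for (int64_t j = 0; j < C; j++) s += attn[i*C + j] * k_beta[j*Dk + d] * std::exp(g[j]);
                k_cumdecay[i*Dk + d] = s;
            }
        }

        // v_new = value - k_cumdecay @ state
        for (int64_t i = 0; i < C; i++) {
            for (int64_t d = 0; d < Dv; d++) {
                float v_prime = 0.0f;
                for (int64_t x = 0; x < Dk; x++) v_prime += k_cumdecay[i*Dk + x] * state[x*Dv + d];
                v_new[i*Dv + d] = value[i*Dv + d] - v_prime;
            }
        }

        // out = (q * exp(g)) @ state + (q @ k^T * decay, causal) @ v_new
        for (int64_t i = 0; i < C; i++) {
            for (int64_t d = 0; d < Dv; d++) {
                float inter = 0.0f;
                for (int64_t x = 0; x < Dk; x++) inter += q[i*Dk + x] * std::exp(g[i]) * state[x*Dv + d];
                float intra = 0.0f;
                for (int64_t j = 0; j <= i; j++) {
                    float qk = 0.0f;
                    for (int64_t x = 0; x < Dk; x++) qk += q[i*Dk + x] * k[j*Dk + x];
                    intra += qk * decay[i*C + j] * v_new[j*Dv + d];
                }
                out[i*Dv + d] = inter + intra;
            }
        }

        // state = state * exp(g[C-1]) + (k * exp(g[C-1] - g))^T @ v_new
        const float g_last = g[C - 1];
        for (int64_t x = 0; x < Dk; x++) {
            for (int64_t d = 0; d < Dv; d++) {
                float s = state[x*Dv + d] * std::exp(g_last);
                for (int64_t i = 0; i < C; i++) s += k[i*Dk + x] * std::exp(g_last - g[i]) * v_new[i*Dv + d];
                state[x*Dv + d] = s;
            }
        }
    }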
+ 0 - 118
ggml/src/ggml.c

@@ -5510,124 +5510,6 @@ struct ggml_tensor * ggml_rwkv_wkv7(
     return result;
 }
 
-// ggml_delta_net
-// prepare all the tensor data for the operation so we only
-// do the absolutely necessary steps in the op itself
-struct ggml_tensor * ggml_delta_net(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        struct ggml_tensor  * g,
-        struct ggml_tensor  * beta,
-        struct ggml_tensor  * state,
-        bool                  use_qk_l2norm,
-        float                 scale,
-        float                 eps_norm
-    ) {
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
-    const int64_t S_k = q->ne[0];
-    const int64_t H_k = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == 1);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
-
-    if (use_qk_l2norm) {
-        q = ggml_l2_norm(ctx, q, eps_norm);
-        k = ggml_l2_norm(ctx, k, eps_norm);
-    }
-    q = ggml_scale(ctx, q, scale);
-
-    int64_t pad_size = (GGML_DELTA_NET_CHUNK - n_tokens % GGML_DELTA_NET_CHUNK) % GGML_DELTA_NET_CHUNK;
-    int64_t num_chunks = (n_tokens + pad_size) / GGML_DELTA_NET_CHUNK;
-
-    // First, permute to chunk format: [n_tokens, S_k, H_k, n_seqs]
-    q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3));
-    k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));
-    v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));
-    beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
-    g = ggml_cont(ctx, ggml_permute(ctx, g, 1, 2, 0, 3));
-
-    // Then, pad the sequence dimension (n_tokens) to chunk_size
-    q = ggml_pad(ctx, q, pad_size, 0, 0, 0); // [CS, S_k, H_k, n_seqs]
-    k = ggml_pad(ctx, k, pad_size, 0, 0, 0); // [CS, S_k, H_k, n_seqs]
-    v = ggml_pad(ctx, v, pad_size, 0, 0, 0); // [CS, S_v, H_v, n_seqs]
-    beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0); // [CS, 1, H_v, n_seqs]
-    g = ggml_pad(ctx, g, pad_size, 0, 0, 0); // [CS, 1, H_v, n_seqs]
-    
-    GGML_ASSERT(q->ne[0] % GGML_DELTA_NET_CHUNK == 0 && q->ne[1] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] % GGML_DELTA_NET_CHUNK == 0 && k->ne[1] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[0] % GGML_DELTA_NET_CHUNK == 0 && v->ne[1] == S_v && v->ne[2] == H_v && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == 1 && beta->ne[2] == H_v && beta->ne[3] == n_seqs);
-    GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[1] == 1 && g->ne[2] == H_v && g->ne[3] == n_seqs);
-
-    struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
-    struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
-    // struct ggml_tensor * mask = ggml_tri(ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK), 1.0f, GGML_TRI_TYPE_UPPER);
-    struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
-        
-    struct ggml_tensor * gcs_i = ggml_cont(ctx, g_cumsum);  // [chunk_size, 1, H_v, n_seqs]
-    struct ggml_tensor * gcs_j = ggml_cont(ctx, ggml_permute(ctx, g_cumsum, 1, 0, 2, 3));  // [1, chunk_size, H_v, n_seqs]
-    
-    // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
-    struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-    struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-    
-    struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast); 
-    
-    // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
-    decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER);   
-    // Apply exponential to get the decay mask values
-    decay_mask = ggml_exp(ctx, decay_mask);
-    // Apply lower triangular mask again to ensure only lower triangular values remain
-    decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER);
-
-    GGML_LOG_INFO("k_beta shape = [%ld, %ld, %ld, %ld], k shape = [%ld, %ld, %ld, %ld]\n", k_beta->ne[0], k_beta->ne[1], k_beta->ne[2], k_beta->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
-    struct ggml_tensor * attn = ggml_neg(ctx, ggml_tri_keep(ctx, ggml_mul(ctx, ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, k_beta)), ggml_cont(ctx, ggml_transpose(ctx, k))), decay_mask), GGML_TRI_TYPE_LOWER));
-    GGML_LOG_INFO("attn shape = [%ld, %ld, %ld, %ld]\n", attn->ne[0], attn->ne[1], attn->ne[2], attn->ne[3]);
-
-    // We'll be returning the result as a 1D tensor due to the dimensions mismatch of the state and output tensors
-    // Use original n_tokens for output size and padded chunk size for state size
-    const int64_t ne[1] = { (S_v * H_v * n_tokens) + (S_v * S_v * H_v * n_seqs) };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 1, ne);
-
-    ggml_set_op_params_i32(result, 0, H_v);
-    ggml_set_op_params_i32(result, 1, S_k);
-    ggml_set_op_params_i32(result, 2, S_v);
-    ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
-
-    result->op     = GGML_OP_DELTA_NET;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-    result->src[3] = g_cumsum;
-    result->src[4] = state;
-    result->src[5] = decay_mask;
-    result->src[6] = v_beta;
-    result->src[7] = k_beta;
-    result->src[8] = attn;
-
-    return result;
-}
 
 
 // ggml_unary
 

+ 1 - 1
src/llama-context.cpp

@@ -1362,7 +1362,7 @@ void llama_context::output_reorder() {
 //
 
 
 uint32_t llama_context::graph_max_nodes() const {
-    return std::max<uint32_t>(1024u, 8u*model.n_tensors());
+    return std::max<uint32_t>(1024u, 32u*model.n_tensors());
 }
 
 llm_graph_result * llama_context::get_gf_res_reserve() const {

+ 178 - 18
src/models/llm_build_qwen3next.cpp

@@ -1,4 +1,5 @@
 #include "llm_build_qwen3next.h"
 #include "llm_build_qwen3next.h"
+#include "../../ggml/src/ggml-impl.h"
 
 
 #include <cmath>
 
@@ -98,7 +99,7 @@ struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * in
     return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
 }
 
-ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *             cur,
+struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *             cur,
                                                                    ggml_tensor *             inp_pos,
                                                                    llm_graph_input_attn_kv * inp_attn,
                                                                    const llama_model &       model,
@@ -152,12 +153,177 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *
     return cur;
 }
 
+// delta_net
+// prepare all the tensor data for the operation so that only the
+// absolutely necessary steps are left to the op itself
+struct ggml_tensor * llm_build_qwen3next::delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 eps_norm,
+        const int             il
+    ) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k = q->ne[0];
+    const int64_t H_k = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == 1);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
+
+    cb(q, "q_prenorm", il);
+    cb(k, "k_prenorm", il);
+
+    if (use_qk_l2norm) {
+        q = ggml_l2_norm(ctx, q, eps_norm);
+        k = ggml_l2_norm(ctx, k, eps_norm);
+    }
+
+    cb(k, "k_postnorm", il);
+    cb(q, "q_prescale", il);
+
+    int64_t pad_size = (GGML_DELTA_NET_CHUNK - n_tokens % GGML_DELTA_NET_CHUNK) % GGML_DELTA_NET_CHUNK;
+    // note: chunking happens along n_tokens, not H_k; the reference implementation's variable naming is misleading
+    int64_t num_chunks = (n_tokens + pad_size) / GGML_DELTA_NET_CHUNK;
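+    // e.g. assuming GGML_DELTA_NET_CHUNK == 64: n_tokens == 100 gives pad_size == 28 and num_chunks == 2,
+    // while an n_tokens that is already a multiple of the chunk size gives pad_size == 0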
+
+    float scale = 1.0f / sqrtf(S_v);
+    q = ggml_scale(ctx, q, scale);
+
+    cb(beta, "beta_raw", il);
+    beta = ggml_sigmoid(ctx, beta);
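+    // beta now lies in (0, 1) and acts as the per-token, per-head write strength of the delta rule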
+
+    cb(q, "q_postscale", il);
+    cb(beta, "beta_sigmoid", il);   
+
+    // First, permute to chunked format: [S_k, n_tokens, H_k, n_seqs]
+    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
+    cb(q, "q_reshape", il);
+    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
+    cb(k, "k_reshape", il);
+    v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
+    cb(v, "v_reshape", il);
+    
+    beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
+    cb(beta, "beta_reshape", il);
+
+    g = ggml_cont(ctx, ggml_permute(ctx, g, 1, 0, 3, 2));
+    cb(g, "g_reshape", il);
+
+    // Then, pad the token dimension (ne[1] after the permute) up to a multiple of the chunk size
+    q = ggml_pad(ctx, q, 0, pad_size, 0, 0); 
+    k = ggml_pad(ctx, k, 0, pad_size, 0, 0);
+    v = ggml_pad(ctx, v, 0, pad_size, 0, 0);
+    // ... except for beta and g, whose token dimension sits in ne[0] after their permutes
+    beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0);
+    g = ggml_pad(ctx, g, pad_size, 0, 0, 0);
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(g, "g_pad", il);
+    
+    GGML_ASSERT(q->ne[1] % GGML_DELTA_NET_CHUNK == 0 && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[1] % GGML_DELTA_NET_CHUNK == 0 && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] % GGML_DELTA_NET_CHUNK == 0 && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == H_k && beta->ne[2] == 1 && beta->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[1] == H_k && g->ne[2] == 1 && g->ne[3] == n_seqs);
+
+    ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, num_chunks * GGML_DELTA_NET_CHUNK, H_k, n_seqs);
+    ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, num_chunks * GGML_DELTA_NET_CHUNK, H_k, n_seqs);
+    cb(beta_unsq, "beta_unsq", il);
+    cb(beta_bcast, "beta_bcast", il);
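+    // beta is viewed as [1, padded_len, H_v, n_seqs] and repeated along ne[0] so it can be multiplied
+    // elementwise with the padded v (and k) tensors below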
+
+    struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta_bcast);
+    cb(v_beta, "v_beta", il);
+    struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta_bcast);
+    cb(k_beta, "k_beta", il);
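+    // pre-scaling v and k by the write strength here keeps the broadcasted multiplies out of the op itself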
+    struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
+    cb(g_cumsum, "g_cumsum", il);
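+    // g is treated as a per-token log-decay; its cumulative sum along the padded token dimension is the
+    // basis for the decay mask built below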
+        
+    struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, num_chunks * GGML_DELTA_NET_CHUNK, 1, H_v, n_seqs);  // [padded_len, 1, H_v, n_seqs]
+    struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [1, padded_len, H_v, n_seqs]
+
+    // Broadcast both tensors to [padded_len, padded_len, H_v, n_seqs], where padded_len = num_chunks * GGML_DELTA_NET_CHUNK
+    struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [padded_len, 1, H_v, n_seqs] -> [padded_len, padded_len, H_v, n_seqs]
+    struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs);  // [1, padded_len, H_v, n_seqs] -> [padded_len, padded_len, H_v, n_seqs]
+    
+    struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast); 
+    cb(decay_mask, "sub", il);
+    
+    // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
+    decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); 
+    cb(decay_mask, "sub_tri", il);
+    // Apply exponential to get the decay mask values
+    decay_mask = ggml_exp(ctx, decay_mask);
+    cb(decay_mask, "sub_tri_exp", il);
+    // Apply lower triangular mask again to ensure only lower triangular values remain
+    decay_mask = ggml_tri_keep(ctx, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+    cb(decay_mask, "decay_mask", il);
+
+    struct ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, ggml_cont(ctx, k_beta), ggml_cont(ctx, k));
+    cb(kmulkbeta, "k @ k_beta", il);
+    
+    struct ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
+    cb(k_decay, "(k @ k_beta) * decay_mask", il);
+
+    struct ggml_tensor * attn = ggml_neg(ctx, ggml_tri_keep(ctx, k_decay, GGML_TRI_TYPE_LOWER));
+    cb(attn, "attn_in", il);
+
+    // The result is returned as a flat 1D tensor because the attention output and the updated state have
+    // different shapes; both are concatenated into one buffer, with the output sized by the original
+    // (unpadded) n_tokens
+    const int64_t ne[1] = { (S_v * H_v * n_tokens) + (S_v * S_v * H_v * n_seqs) };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 1, ne);
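+    // flat layout: S_v*H_v*n_tokens values for the attention output plus S_v*S_v*H_v*n_seqs values for the
+    // updated recurrent state; the caller splits them apart again after the op runs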
+
+    ggml_set_op_params_i32(result, 0, H_v);
+    ggml_set_op_params_i32(result, 1, S_k);
+    ggml_set_op_params_i32(result, 2, S_v);
+    ggml_set_op_params_i32(result, 3, n_tokens); // Pass original n_tokens
+
+    result->op     = GGML_OP_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g_cumsum;
+    result->src[4] = state;
+    result->src[5] = decay_mask;
+    result->src[6] = v_beta;
+    result->src[7] = k_beta;
+    result->src[8] = attn;
+
+    return result;
+}
+
+
 ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
                                                                      ggml_tensor *        cur,
                                                                      const llama_model &  model,
                                                                      const llama_ubatch & ubatch,
                                                                      int                  il) {
-    // Gated Delta Net implementation using the new ggml_delta_net function
+    // Gated Delta Net implementation using the new delta_net function
     const auto * mctx_cur = inp->mctx;
 
     const int64_t d_inner  = hparams.ssm_d_inner;
@@ -352,31 +518,25 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, n_seqs);
     v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, n_seqs);
 
-    beta  = ggml_cont_4d(ctx0, b, 1, num_v_heads, n_tokens, n_seqs);
-    alpha = ggml_cont_4d(ctx0, a, 1, num_v_heads, n_tokens, n_seqs);
+    beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tokens, n_seqs);
 
 
     ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_states_all, head_dim, head_dim * n_heads, 1, 1);
-    gate = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
 
 
-        // if head keys and value keys are different, repeat to force tensors into matching shapes
+    // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
-        q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, n_tokens, num_k_heads, n_seqs);
-        k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, n_tokens, num_k_heads, n_seqs);
-
-        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, n_tokens * repeat_factor, num_k_heads, n_seqs);
-        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, n_tokens * repeat_factor, num_k_heads, n_seqs);
-
-        // Fix dimension order: last two should be [tokens, batches]
-        q_conv = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_tokens, n_seqs);
-        k_conv = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_tokens, n_seqs);
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_tokens, n_seqs);
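+        // q_conv and k_conv now carry num_v_heads heads along ne[1], matching v_conv before entering delta_net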
     }
 
-    // Call the new ggml_delta_net function with the corrected flow
-    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(num_k_heads)) : hparams.f_attention_scale;
-    ggml_tensor * attn_out = ggml_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, kq_scale, hparams.f_norm_rms_eps);
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    // Call the new delta_net function with the corrected flow
+    ggml_tensor * attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well

+ 12 - 11
src/models/llm_build_qwen3next.h

@@ -10,17 +10,18 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
     llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
 
 private:
-    // ggml_delta_net
-    ggml_tensor * ggml_delta_net_op(struct ggml_tensor * q,
-                                   struct ggml_tensor * k,
-                                   struct ggml_tensor * v,
-                                   struct ggml_tensor * g,
-                                   struct ggml_tensor * beta,
-                                   struct ggml_tensor * state,
-                                   bool                 use_qk_l2norm,
-                                   float                scale,
-                                   float                eps_norm,
-                                   int                  il);
+    // delta_net
+    struct ggml_tensor * delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 eps_norm,
+        const int             il);
 
 
     ggml_tensor * build_qwen3next_attention_layer(ggml_tensor *             cur,
                                                   ggml_tensor *             inp_pos,