
Now that eval's running, move delta net stuff back to llama-model, add cbs

Piotr Wilkin 3 months ago
parent
commit
43eb7a7757
5 changed files with 183 additions and 321 deletions
  1. ggml/CMakeLists.txt (+0, -1)
  2. ggml/include/ggml-delta.h (+0, -43)
  3. ggml/src/CMakeLists.txt (+0, -2)
  4. ggml/src/ggml-delta.c (+0, -197)
  5. src/llama-model.cpp (+183, -78)

+ 0 - 1
ggml/CMakeLists.txt

@@ -273,7 +273,6 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-cpp.h
     include/ggml-cuda.h
     include/ggml-opt.h
-    include/ggml-delta.h
     include/ggml-metal.h
     include/ggml-rpc.h
     include/ggml-sycl.h

+ 0 - 43
ggml/include/ggml-delta.h

@@ -1,43 +0,0 @@
-#pragma once
-
-#include "ggml-backend.h"
-#include "ggml.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// Delta-Net linear layer activation
-// Implements the complete Delta-Net gated linear attention mechanism
-// This includes causal convolution preprocessing and gated delta rule computation
-// k, v, q, g: [S, H, n_tokens, n_seqs] - key, value, query, gate tensors
-// conv_weight: [conv_dim, 1, conv_kernel_size] - convolution kernel weights
-// conv_bias: [conv_dim] - convolution bias (optional, can be NULL)
-// beta: [H, n_tokens, n_seqs] - beta parameter for delta rule
-// state: [S, S, H, n_seqs] - recurrent state tensor
-// chunk_size: chunk size for chunked computation (0 for recurrent mode)
-// use_qk_l2norm: whether to apply L2 normalization to query and key
-// scale: attention scaling factor
-GGML_API struct ggml_tensor * ggml_delta_net(struct ggml_context * ctx,
-                                             struct ggml_tensor *  k,
-                                             struct ggml_tensor *  v,
-                                             struct ggml_tensor *  q,
-                                             struct ggml_tensor *  g,
-                                             struct ggml_tensor *  beta,
-                                             struct ggml_tensor *  state,
-                                             bool                  use_qk_l2norm,
-                                             float                 scale);
-
-GGML_API struct ggml_tensor * ggml_delta_net_op(struct ggml_context * ctx,
-                                                struct ggml_tensor *  q,
-                                                struct ggml_tensor *  k,
-                                                struct ggml_tensor *  v,
-                                                struct ggml_tensor *  g,
-                                                struct ggml_tensor *  beta,
-                                                struct ggml_tensor *  state,
-                                                bool                  use_qk_l2norm,
-                                                float                 scale);
-
-#ifdef __cplusplus
-}
-#endif
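
For orientation, the removed entry point was called with the shapes enforced by the runtime asserts in ggml-delta.c (the doc comment above still mentions conv_weight, conv_bias and chunk_size, which were not part of the final signature). A minimal sketch of a pre-commit call, with purely illustrative dimensions and an already-initialized ggml_context:

    // Minimal usage sketch of the removed public API (illustrative dimensions, F32 inputs).
    // Shapes follow the asserts in ggml-delta.c:
    //   k/v/q/g : [S, H, n_tokens, n_seqs]
    //   beta    : [H, n_tokens, 1, n_seqs]
    //   state   : [S, S*H, n_seqs, n_tokens]
    #include "ggml.h"
    #include "ggml-delta.h"   // this header, as it existed before this commit

    static struct ggml_tensor * build_delta_net_example(struct ggml_context * ctx) {
        const int64_t S = 128, H = 16, n_tokens = 8, n_seqs = 1;

        struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, n_tokens, n_seqs);
        struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, n_tokens, n_seqs);
        struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, n_tokens, n_seqs);
        struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, H, n_tokens, n_seqs);
        struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, H, n_tokens, 1, n_seqs);
        struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, S * H, n_seqs, n_tokens);

        // use_qk_l2norm = true and scale = 1.0f are the values used at the call site in llama-model.cpp;
        // the result concatenates the per-token output and the updated state along dim 1
        return ggml_delta_net(ctx, k, v, q, g, beta, state, true, 1.0f);
    }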

+ 0 - 2
ggml/src/CMakeLists.txt

@@ -194,9 +194,7 @@ add_library(ggml-base
             ../include/ggml-cpp.h
             ../include/ggml-opt.h
             ../include/gguf.h
-            ../include/ggml-delta.h
             ggml.c
-            ggml-delta.c
             ggml.cpp
             ggml-alloc.c
             ggml-backend.cpp

+ 0 - 197
ggml/src/ggml-delta.c

@@ -1,197 +0,0 @@
-#include "ggml.h"
-#include "ggml-delta.h"
-#include "ggml-impl.h"
-
-static void report_tensor_size(const char * tensor_name, const struct ggml_tensor * tensor) {
-#ifdef HAVE_DEBUG_DELTA_NET
-    GGML_LOG_INFO("[%s] tensor size is [%lu, %lu, %lu, %lu], strides [%lu, %lu, %lu, %lu]\n", 
-        tensor_name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
-        tensor->nb[0], tensor->nb[1], tensor->nb[2], tensor->nb[3]);
-#endif
-}
-
-// ggml_delta_net
-struct ggml_tensor * ggml_delta_net(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * g,
-        struct ggml_tensor  * beta,
-        struct ggml_tensor  * state,
-        bool                  use_qk_l2norm,
-        float                 scale) {
-    
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-    report_tensor_size("orig_k", k);
-    report_tensor_size("orig_v", v);
-    report_tensor_size("orig_q", q);
-    report_tensor_size("orig_g", g);
-    report_tensor_size("orig_beta", beta);
-    report_tensor_size("orig_state", state);
-    
-    const int64_t S_k = k->ne[0];
-    const int64_t H_k = k->ne[1];
-    const int64_t n_tokens = k->ne[2];  
-    const int64_t n_seqs = k->ne[3];
-    
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-    
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seqs && state->ne[3] == n_tokens);
-    
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
-       
-    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
-       
-    // Beta sigmoid
-    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-    report_tensor_size("beta_sigmoid", beta_sigmoid);
-
-    // Gate calculations are done elsewhere in llama-model.cpp
-
-    struct ggml_tensor * q_broadcast = q;
-    struct ggml_tensor * k_broadcast = k;
-    
-    // if head keys and value keys are different, repeat to force tensors into matching shapes
-    if (H_k != H_v) {
-        GGML_ASSERT(H_v % H_k == 0);
-        int64_t repeat_factor = H_v / H_k;
-        
-        q_broadcast = ggml_cont_4d(ctx, q, S_k, n_tokens, H_k, n_seqs);
-        report_tensor_size("q_broadcast_reshape1", q_broadcast);
-        k_broadcast = ggml_cont_4d(ctx, k, S_k, n_tokens, H_k, n_seqs);
-        report_tensor_size("k_broadcast_reshape1", k_broadcast);
-        
-        q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
-        report_tensor_size("q_broadcast_repeat", q_broadcast);
-        k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
-        report_tensor_size("k_broadcast_repeat", k_broadcast);
-        
-        q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_seqs, n_tokens);
-        report_tensor_size("q_broadcast_reshape2", q_broadcast);
-        k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_seqs, n_tokens);
-        report_tensor_size("k_broadcast_reshape2", k_broadcast);
-    }
-
-    struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v, S_v, H_v, n_seqs, n_tokens);
-    report_tensor_size("v_reshape", v_reshape);
-    struct ggml_tensor * g_reshape = ggml_cont_4d(ctx, g, S_v, H_v, n_seqs, n_tokens);
-    report_tensor_size("g_reshape", g_reshape);
-    struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx, beta, 1, H_v, n_seqs, n_tokens);
-    report_tensor_size("beta_broadcast", beta_broadcast);
-    struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
-    report_tensor_size("state_broadcast", state_broadcast);
-    
-    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_reshape, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
-}
-
-struct ggml_tensor * ggml_delta_net_op(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        struct ggml_tensor  * g,
-        struct ggml_tensor  * beta,
-        struct ggml_tensor  * state,
-        bool                  use_qk_l2norm,
-        float                 scale) {
-    
-    // Debug: Log input tensor dimensions
-    report_tensor_size("q_input", q);
-    report_tensor_size("k_input", k);
-    report_tensor_size("v_input", v);
-    report_tensor_size("g_input", g);
-    report_tensor_size("beta_input", beta);
-    report_tensor_size("state_input", state);
-    
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-    
-    const int64_t S_k = q->ne[0];  
-    const int64_t H_k = q->ne[1];  
-    const int64_t n_seq = q->ne[2];  
-    const int64_t n_tokens = q->ne[3];
-    
-    const int64_t S_v = v->ne[0];  
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(H_k == H_v); // we broadcasted the tensors in the main function to guarantee this
-    
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_seq && k->ne[3] == n_tokens);
-    GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_seq && v->ne[3] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_seq && g->ne[3] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == H_v && beta->ne[2] == n_seq && beta->ne[3] == n_tokens);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seq && state->ne[3] == n_tokens);
-       
-    struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, S_v * H_v, n_seq, n_tokens);
-    
-    new_state = ggml_cpy(ctx, state, new_state);
-    report_tensor_size("new_state_copied", new_state);
-    
-    if (use_qk_l2norm) {
-        q = ggml_l2_norm(ctx, q, 1e-6f);
-        report_tensor_size("q_l2norm", q);
-        k = ggml_l2_norm(ctx, k, 1e-6f);
-        report_tensor_size("k_l2norm", k);
-    }
-    
-    q = ggml_scale(ctx, q, scale);
-    report_tensor_size("q_scaled", q);
-    
-    struct ggml_tensor * state_flat = ggml_reshape_2d(ctx, new_state, S_v * S_v * H_v, n_seq * n_tokens);
-    report_tensor_size("state_flat", state_flat);
-                                  
-    struct ggml_tensor * state_decay = ggml_mul(ctx, state, g);
-    report_tensor_size("state_decay", state_decay);
-               
-    struct ggml_tensor * kv_mem_presum = ggml_mul(ctx, state_decay, k);
-    report_tensor_size("kv_mem_presum", kv_mem_presum);
-
-    // Gotta do some squeezing here...
-    struct ggml_tensor * kv_mem_presum_squeeze = ggml_reshape_4d(ctx, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
-    report_tensor_size("kv_mem_presum_sequeeze", kv_mem_presum_squeeze);
-
-    struct ggml_tensor * kv_mem = ggml_permute(ctx, ggml_sum_rows(ctx, ggml_cont(ctx, ggml_permute(ctx, kv_mem_presum_squeeze, 1, 2, 0, 3))), 2, 0, 1, 3);
-    report_tensor_size("kv_mem", kv_mem);
-
-    struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx, kv_mem, S_v, S_v, n_seq, n_tokens);
-    report_tensor_size("kv_mem_reshape", kv_mem_reshape);
-                
-    struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, kv_mem_reshape, v), beta);
-    report_tensor_size("delta", delta);
-
-    struct ggml_tensor * delta_kt = ggml_mul(ctx, delta, k);
-    report_tensor_size("delta_kt", delta_kt);
-
-    struct ggml_tensor * state_plus_k_delta = ggml_add(ctx, state_decay, delta_kt);
-    report_tensor_size("state_plus_k_delta", state_plus_k_delta);
-
-    struct ggml_tensor * state_q = ggml_mul(ctx, state_plus_k_delta, q);
-    report_tensor_size("state_q", state_q);
-
-    // And here...
-    state_q = ggml_reshape_4d(ctx, state_q, S_v, S_v, H_v, n_seq * n_tokens);
-    struct ggml_tensor * output = ggml_permute(ctx, ggml_sum_rows(ctx, state_q), 2, 0, 1, 3);
-    output = ggml_reshape_4d(ctx, output, S_v, H_v, n_seq, n_tokens);
-    report_tensor_size("output", output);
-    
-    struct ggml_tensor * result = ggml_concat(ctx, output, state_plus_k_delta, 1);
-    report_tensor_size("result_final", result);
-    return result;
-}
-
-
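
The graph built above (and rebuilt verbatim inside llama-model.cpp below) fuses the recurrence over all heads, sequences and tokens at once, which makes the data flow hard to follow from the tensor ops alone. As an orientation aid, here is a scalar, single-head reference of the gated delta rule that the comments describe (state decay, key readout, beta-weighted delta write, query readout). It is the conventional formulation of the rule, not a line-by-line translation of the ggml graph, so treat the name correspondences in the comments as approximate.

    #include <vector>

    // Scalar reference of the gated delta rule for one head and one sequence.
    // S is the d_k x d_v recurrent state, stored row-major as S[i*d_v + j].
    // g is a per-channel decay in (0,1), beta a per-token write strength in (0,1).
    static void gated_delta_rule_ref(int n_tokens, int d_k, int d_v,
                                     const std::vector<float> & q,     // n_tokens * d_k
                                     const std::vector<float> & k,     // n_tokens * d_k
                                     const std::vector<float> & v,     // n_tokens * d_v
                                     const std::vector<float> & g,     // n_tokens * d_v
                                     const std::vector<float> & beta,  // n_tokens
                                     std::vector<float> & S,           // d_k * d_v
                                     std::vector<float> & out) {       // n_tokens * d_v
        std::vector<float> v_hat(d_v);
        for (int t = 0; t < n_tokens; ++t) {
            // 1. decay the state with the gate (cf. "state_decay")
            for (int i = 0; i < d_k; ++i)
                for (int j = 0; j < d_v; ++j)
                    S[i*d_v + j] *= g[t*d_v + j];
            // 2. read the memory for the current key: v_hat = S^T k_t (cf. "kv_mem")
            for (int j = 0; j < d_v; ++j) {
                v_hat[j] = 0.0f;
                for (int i = 0; i < d_k; ++i)
                    v_hat[j] += S[i*d_v + j] * k[t*d_k + i];
            }
            // 3. delta rule: write the beta-scaled prediction error back into the state
            for (int i = 0; i < d_k; ++i)
                for (int j = 0; j < d_v; ++j)
                    S[i*d_v + j] += beta[t] * (v[t*d_v + j] - v_hat[j]) * k[t*d_k + i];
            // 4. read out with the (optionally L2-normalized, scaled) query: o_t = S^T q_t
            for (int j = 0; j < d_v; ++j) {
                float o = 0.0f;
                for (int i = 0; i < d_k; ++i)
                    o += S[i*d_v + j] * q[t*d_k + i];
                out[t*d_v + j] = o;
            }
        }
    }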

+ 183 - 78
src/llama-model.cpp

@@ -12,7 +12,6 @@
 #include "llama-memory-recurrent.h"
 
 #include "ggml-cpp.h"
-#include "ggml-delta.h"
 
 #include <algorithm>
 #include <cassert>
@@ -18970,9 +18969,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
             struct ggml_tensor * inpSA = inpL;
 
             // Pre-norm for attention/linear attention
-            cur = build_norm(inpL,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, il);
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
             // Determine layer type and build appropriate attention mechanism
@@ -18981,19 +18978,15 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
                 cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
             } else {
                 // Full attention layer
-                cur = build_qwen3next_attention_layer(
-                    cur, inp_pos, inp->get_attn(), model,
-                    n_embd_head, il);
+                cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
             }
 
             // Post-attention norm
-            cur = build_norm(cur,
-                    model.layers[il].attn_post_norm, NULL,
-                    LLM_NORM_RMS, il);
+            cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "attn_post_norm", il);
 
             if (il == n_layer - 1 && inp_out_ids) {
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
@@ -19011,9 +19004,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         cur = inpL;
 
         // Final norm
-        cur = build_norm(cur,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, -1);
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
 
         cb(cur, "result_norm", -1);
         res->t_embd = cur;
@@ -19028,15 +19019,148 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_build_forward_expand(gf, cur);
     }
 
-private:
-    ggml_tensor * build_qwen3next_attention_layer(
-              ggml_tensor             * cur,
-              ggml_tensor             * inp_pos,
-              llm_graph_input_attn_kv * inp_attn,
-        const llama_model             & model,
-        const int64_t                 n_embd_head,
-        const int                     il) {
+  private:
+    // ggml_delta_net
+    struct ggml_tensor * ggml_delta_net(struct ggml_tensor *  k, struct ggml_tensor *  v, struct ggml_tensor *  q, struct ggml_tensor *  g,
+                                        struct ggml_tensor *  beta, struct ggml_tensor *  state, bool use_qk_l2norm, float scale, int il) {
+        GGML_ASSERT(ggml_is_contiguous(k));
+        GGML_ASSERT(ggml_is_contiguous(v));
+        GGML_ASSERT(ggml_is_contiguous(q));
+        GGML_ASSERT(ggml_is_contiguous(g));
+        GGML_ASSERT(ggml_is_contiguous(beta));
+        GGML_ASSERT(ggml_is_contiguous(state));
+
+        const int64_t S_k      = k->ne[0];
+        const int64_t H_k      = k->ne[1];
+        const int64_t n_tokens = k->ne[2];
+        const int64_t n_seqs   = k->ne[3];
+
+        const int64_t S_v = v->ne[0];
+        const int64_t H_v = v->ne[1];
+
+        GGML_ASSERT(v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[2] == n_tokens);
+        GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[3] == n_seqs);
+        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seqs &&
+                    state->ne[3] == n_tokens);
+
+        GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+        GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+
+        GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+
+        // Beta sigmoid
+        struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx0, beta);
+        cb(beta_sigmoid, "beta_sigmoid", il);
+
+        // Gate calculations are done elsewhere in llama-model.cpp
+
+        struct ggml_tensor * q_broadcast = q;
+        struct ggml_tensor * k_broadcast = k;
+
+        // if head keys and value keys are different, repeat to force tensors into matching shapes
+        if (H_k != H_v) {
+            GGML_ASSERT(H_v % H_k == 0);
+            int64_t repeat_factor = H_v / H_k;
 
+            q_broadcast = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
+            k_broadcast = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
+
+            q_broadcast = ggml_repeat_4d(ctx0, q_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
+            k_broadcast = ggml_repeat_4d(ctx0, k_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
+
+            q_broadcast = ggml_reshape_4d(ctx0, q_broadcast, S_k, H_v, n_seqs, n_tokens);
+            k_broadcast = ggml_reshape_4d(ctx0, k_broadcast, S_k, H_v, n_seqs, n_tokens);
+        }
+
+        struct ggml_tensor * v_reshape = ggml_cont_4d(ctx0, v, S_v, H_v, n_seqs, n_tokens);
+        struct ggml_tensor * g_reshape = ggml_cont_4d(ctx0, g, S_v, H_v, n_seqs, n_tokens);
+        struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx0, beta_sigmoid, 1, H_v, n_seqs, n_tokens);
+        struct ggml_tensor * state_broadcast = ggml_cont(ctx0, state);
+
+        return ggml_delta_net_op(q_broadcast, k_broadcast, v_reshape, g_reshape, beta_broadcast, state_broadcast,
+                                 use_qk_l2norm, scale, il);
+    }
+
+    struct ggml_tensor * ggml_delta_net_op(struct ggml_tensor *  q, struct ggml_tensor *  k, struct ggml_tensor *  v, struct ggml_tensor *  g,
+                                           struct ggml_tensor *  beta, struct ggml_tensor *  state, bool use_qk_l2norm, float scale, int il) {
+        GGML_ASSERT(ggml_is_contiguous(q));
+        GGML_ASSERT(ggml_is_contiguous(k));
+        GGML_ASSERT(ggml_is_contiguous(v));
+        GGML_ASSERT(ggml_is_contiguous(g));
+        GGML_ASSERT(ggml_is_contiguous(beta));
+        GGML_ASSERT(ggml_is_contiguous(state));
+
+        const int64_t S_k      = q->ne[0];
+        const int64_t H_k      = q->ne[1];
+        const int64_t n_seq    = q->ne[2];
+        const int64_t n_tokens = q->ne[3];
+
+        const int64_t S_v = v->ne[0];
+        const int64_t H_v = v->ne[1];
+
+        GGML_ASSERT(H_k == H_v);  // we broadcasted the tensors in the main function to guarantee this
+
+        GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_seq && k->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_seq && v->ne[3] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_seq && g->ne[3] == n_tokens);
+        GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == H_v && beta->ne[2] == n_seq && beta->ne[3] == n_tokens);
+        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seq &&
+                    state->ne[3] == n_tokens);
+
+        struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, S_v, S_v * H_v, n_seq, n_tokens);
+
+        new_state = ggml_cpy(ctx0, state, new_state);
+        cb(new_state, "new_state", il);
+
+        if (use_qk_l2norm) {
+            q = ggml_l2_norm(ctx0, q, 1e-6f);
+            cb(q, "q_l2_norm", il);
+            k = ggml_l2_norm(ctx0, k, 1e-6f);
+            cb(k, "k_l2_norm", il);
+        }
+
+        q = ggml_scale(ctx0, q, scale);
+        cb(q, "q_scaled", il);
+
+        struct ggml_tensor * state_decay = ggml_mul(ctx0, state, g);
+        cb(state_decay, "state_decay", il);
+        struct ggml_tensor * kv_mem_presum = ggml_mul(ctx0, state_decay, k);
+
+        // Gotta do some squeezing here...
+        struct ggml_tensor * kv_mem_presum_squeeze =
+            ggml_reshape_4d(ctx0, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
+
+        struct ggml_tensor * kv_mem = ggml_permute(
+            ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, kv_mem_presum_squeeze, 1, 2, 0, 3))), 2, 0, 1, 3);
+        cb(kv_mem, "kv_mem", il);
+        struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx0, kv_mem, S_v, S_v, n_seq, n_tokens);
+        struct ggml_tensor * delta = ggml_mul(ctx0, ggml_sub(ctx0, kv_mem_reshape, v), beta);
+        cb(delta, "delta", il);
+        struct ggml_tensor * delta_kt = ggml_mul(ctx0, delta, k);
+        cb(delta_kt, "delta_kt", il);
+        struct ggml_tensor * state_plus_k_delta = ggml_add(ctx0, state_decay, delta_kt);
+        cb(state_plus_k_delta, "state_plus_k_delta", il);
+        struct ggml_tensor * state_q = ggml_mul(ctx0, state_plus_k_delta, q);
+        cb(state_q, "state_q", il);
+
+        // And here...
+        state_q                     = ggml_reshape_4d(ctx0, state_q, S_v, S_v, H_v, n_seq * n_tokens);
+        struct ggml_tensor * output = ggml_permute(ctx0, ggml_sum_rows(ctx0, state_q), 2, 0, 1, 3);
+        output                      = ggml_reshape_4d(ctx0, output, S_v, H_v, n_seq, n_tokens);
+        cb(output, "delta_net_output", il);
+
+        struct ggml_tensor * result = ggml_concat(ctx0, output, state_plus_k_delta, 1);
+        cb(result, "delta_net_result", il);
+        return result;
+    }
+
+    ggml_tensor * build_qwen3next_attention_layer(ggml_tensor *             cur,
+                                                  ggml_tensor *             inp_pos,
+                                                  llm_graph_input_attn_kv * inp_attn,
+                                                  const llama_model &       model,
+                                                  const int64_t             n_embd_head,
+                                                  const int                 il) {
         ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
 
         // compute Q and K and RoPE them
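
The cb(...) calls added throughout the two methods above give every intermediate a stable name so it can be watched at eval time (the "add cbs" part of the commit message). Below is a sketch of a debug hook that picks those names up through the scheduler eval callback; the prefix matching and the cb_eval wiring are assumptions based on the existing eval-callback plumbing in llama.cpp, and the layer index is normally appended to each name, hence the prefix match.

    #include <cstdio>
    #include <cstring>

    #include "ggml.h"

    // Hypothetical debug hook: report delta-net intermediates by name prefix.
    // Signature follows ggml_backend_sched_eval_callback (ggml-backend.h).
    static bool delta_net_eval_cb(struct ggml_tensor * t, bool ask, void * /*user_data*/) {
        const bool interesting = strncmp(t->name, "kv_mem", 6) == 0 ||
                                 strncmp(t->name, "delta",  5) == 0 ||
                                 strncmp(t->name, "state_", 6) == 0;
        if (ask) {
            return interesting;   // request data only for the tensors we care about
        }
        if (interesting) {
            printf("%-24s [%lld, %lld, %lld, %lld]\n", t->name,
                   (long long) t->ne[0], (long long) t->ne[1],
                   (long long) t->ne[2], (long long) t->ne[3]);
        }
        return true;              // keep evaluating the rest of the graph
    }

    // wiring (sketch): llama_context_params cparams = llama_context_default_params();
    //                  cparams.cb_eval           = delta_net_eval_cb;
    //                  cparams.cb_eval_user_data = nullptr;
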
@@ -19060,30 +19184,26 @@ private:
         cb(Kcur, "Kcur_normed", il);
 
         // Apply RoPE
-        Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
+        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                             ext_factor, attn_factor, beta_fast, beta_slow);
 
-        Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
+        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                             ext_factor, attn_factor, beta_fast, beta_slow);
 
         cb(Qcur, "Qcur", il);
         cb(Kcur, "Kcur", il);
         cb(Vcur, "Vcur", il);
 
         // Attention computation
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        cur = build_attn(inp_attn,
-                model.layers[il].wo, nullptr,
-                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
-        
+        const float kq_scale =
+            hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale,
+                         il);
+
         // Apply gating
         cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);
-        
+
         return cur;
     }
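
The gating applied just above multiplies the attention output elementwise by a sigmoid of a separate projection (wq_gate) of the pre-attention input. A scalar sketch of that step, shown only to make the formula explicit:

    #include <cmath>
    #include <vector>

    // out = attn_out * sigmoid(gate), applied elementwise
    static void apply_output_gate(std::vector<float> & attn_out, const std::vector<float> & gate) {
        for (size_t i = 0; i < attn_out.size(); ++i) {
            attn_out[i] *= 1.0f / (1.0f + std::exp(-gate[i]));
        }
    }
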
 
@@ -19252,16 +19372,18 @@ private:
         cb(conv_output, "conv_output_final", il);
 
         // Extract the convolved Q, K, V from conv_output
-        ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
-                                            head_k_dim, conv_output->nb[1], conv_output->nb[2], 0));
-        cb(q_conv, "q_conv", il);
-        ggml_tensor * k_conv =
+        ggml_tensor * q_conv =
             ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs, head_k_dim,
-                         conv_output->nb[1], conv_output->nb[2], head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+                                         conv_output->nb[1], conv_output->nb[2], 0));
+        cb(q_conv, "q_conv", il);
+        ggml_tensor * k_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens,
+                                                            n_seqs, head_k_dim, conv_output->nb[1], conv_output->nb[2],
+                                                            head_k_dim * num_k_heads * ggml_element_size(conv_output)));
         cb(q_conv, "k_conv", il);
         ggml_tensor * v_conv =
             ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs, head_v_dim,
-                         conv_output->nb[1], conv_output->nb[2], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+                                         conv_output->nb[1], conv_output->nb[2],
+                                         2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
         cb(q_conv, "v_conv", il);
 
         ggml_build_forward_expand(gf, ssm_states_all);
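
The three views above carve Q, K and V back out of the fused convolution output; ggml_view_4d takes byte strides and a byte offset, so the only arithmetic that matters is how far each slice starts into a row. A tiny standalone sketch of that offset math, using illustrative dimensions rather than the model's real hyperparameters:

    #include <cstddef>
    #include <cstdio>

    // Offset arithmetic for slicing the fused conv output, assuming the row layout
    // [ head_k_dim*num_k_heads | head_k_dim*num_k_heads | head_v_dim*num_v_heads ]
    // along dim 0 (illustrative dims, F32 element size).
    int main() {
        const size_t head_k_dim = 128, num_k_heads = 16;
        const size_t esize      = sizeof(float);   // stands in for ggml_element_size(conv_output)

        const size_t q_offset = 0;
        const size_t k_offset =     head_k_dim * num_k_heads * esize;
        const size_t v_offset = 2 * head_k_dim * num_k_heads * esize;

        printf("q at %zu, k at %zu, v at %zu bytes into each conv_output row\n",
               q_offset, k_offset, v_offset);
        return 0;
    }
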
@@ -19274,28 +19396,21 @@ private:
         ggml_tensor * target_gate     = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
         ggml_tensor * gate_broadcast  = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
         gate                          = ggml_repeat(ctx0, gate_broadcast, target_gate);
+        cb(gate, "gate", il);
 
         // Call the new ggml_delta_net function with the corrected flow
-        ggml_tensor * output = ggml_delta_net(ctx0,
-                                              k_conv,           // k tensor (already convolved)
-                                              v_conv,           // v tensor (already convolved)
-                                              q_conv,           // q tensor (already convolved)
-                                              gate,             // g tensor
-                                              beta,             // beta tensor
-                                              state_broadcast,  // state tensor
-                                              true,             // use_qk_l2norm
-                                              1.0f              // scale
-        );
-        cb(output, "delta_net_output", il);
+        ggml_tensor * output = ggml_delta_net(k_conv, v_conv, q_conv, gate, beta, state_broadcast, true, 1.0f, il);
 
         // Extract the output part
         ggml_tensor * attn_out = ggml_view_4d(ctx0, output, head_dim, n_heads, n_tokens, n_seqs, output->nb[0],
                                               output->nb[1], output->nb[2], 0);
+        cb(attn_out, "attn_out", il);
 
         // Extract the new state
         ggml_tensor * new_state =
             ggml_view_4d(ctx0, output, head_dim, head_dim * n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1],
                          output->nb[2], n_tokens * n_seqs * head_dim * n_heads * ggml_element_size(output));
+        cb(new_state, "new_state", il);
 
         // Only return the last recurrent state
         struct ggml_tensor * state_reshaped =
@@ -19303,6 +19418,7 @@ private:
         struct ggml_tensor * state_last = ggml_view_4d(
             ctx0, state_reshaped, head_dim, head_dim, n_heads, 1, state_reshaped->nb[1], state_reshaped->nb[2],
             state_reshaped->nb[3], head_dim * head_dim * n_heads * ((n_seqs * n_tokens) - 1));
+        cb(state_last, "new_state_last", il);
 
         // Update the recurrent states
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_last, ssm_states_all));
@@ -19318,16 +19434,20 @@ private:
         // Apply gated normalization: self.norm(core_attn_out, z)
         // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
         ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        cb(attn_out_norm, "attn_out_norm", il);
 
         // Apply silu gate: attn_out_norm * silu(z_2d)
         ggml_tensor * z_silu       = ggml_silu(ctx0, z_2d);
+        cb(z_silu, "z_silu", il);
         ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
+        cb(gated_output, "gated_output", il);
 
         // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
         ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
 
         // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
         ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
+        cb(final_output, "final_output", il);
 
         // Output projection
         cur = build_lora_mm(model.layers[il].ssm_out, final_output);
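
The two steps above implement what the comment calls Qwen3NextRMSNormGated: an RMSNorm of the flattened attention output followed by an elementwise product with silu(z). A scalar sketch of that composition for a single head_dim-sized row; the epsilon and the weight vector are illustrative placeholders for the values carried by hparams and ssm_norm:

    #include <cmath>
    #include <vector>

    // RMSNorm(x) * silu(z), applied independently to each head_dim-sized row.
    // w is the learned norm weight (ssm_norm); eps is an illustrative default.
    static void rms_norm_gated(std::vector<float> & x, const std::vector<float> & w,
                               const std::vector<float> & z, float eps = 1e-6f) {
        const size_t d = x.size();
        float ss = 0.0f;
        for (size_t i = 0; i < d; ++i) {
            ss += x[i] * x[i];
        }
        const float inv_rms = 1.0f / std::sqrt(ss / d + eps);
        for (size_t i = 0; i < d; ++i) {
            const float silu_z = z[i] / (1.0f + std::exp(-z[i]));  // z * sigmoid(z)
            x[i] = x[i] * inv_rms * w[i] * silu_z;
        }
    }
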
@@ -19343,27 +19463,17 @@ private:
         // Check if this is an MoE layer
         if (model.layers[il].ffn_gate_inp != nullptr) {
             // MoE branch
-            ggml_tensor * moe_out = build_moe_ffn(cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                              model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert,
+                              n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
             cb(moe_out, "ffn_moe_out", il);
 
             // Add shared experts if present
             if (model.layers[il].ffn_up_shexp != nullptr) {
-                ggml_tensor * ffn_shexp = build_ffn(cur,
-                    model.layers[il].ffn_up_shexp,   NULL, NULL,
-                    model.layers[il].ffn_gate_shexp, NULL, NULL,
-                    model.layers[il].ffn_down_shexp, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL,
+                              NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
                 cb(ffn_shexp, "ffn_shexp", il);
 
                 cur = ggml_add(ctx0, moe_out, ffn_shexp);
@@ -19373,17 +19483,13 @@ private:
             }
         } else {
             // Dense FFN branch
-            cur = build_ffn(cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
             cb(cur, "ffn_out", il);
         }
 
         // Residual connection
-        cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
+        cur = ggml_add(ctx0, cur, cur);  // This should be the residual from before FFN
         cb(cur, "ffn_residual", il);
 
         return cur;
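
The cur + cur addition above is flagged by its own comment as a stand-in: it doubles the FFN output rather than adding back the pre-FFN activation. One way to wire the intended residual, sketched with a hypothetical ffn_inp holding the tensor saved before the FFN call:

    // sketch only: save the FFN input, then add it back after the FFN
    ggml_tensor * ffn_inp = cur;                 // taken before build_ffn / build_moe_ffn
    // ... cur = build_moe_ffn(...) or build_ffn(...) as above ...
    cur = ggml_add(ctx0, cur, ffn_inp);          // residual from before the FFN
    cb(cur, "ffn_residual", il);
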
@@ -19398,7 +19504,6 @@ private:
     }
 };
 
-
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;