model : Qwen3 Next (#16095)

* Qwen3 Next - cleaned up version

* Whitespaces and stuff

* Correct minor errors

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Misc. fixes.

* Clean up code, add missing hybrid qualifier

* Did someone transpose the SOLVE_TRI result matrix? Perhaps...

* Whitespace

* Proper tensors for cb calls

* Use llama-graph.h vertical alignment

* BROKEN: chunking

* Set new tensors as inputs.

* Proper chunk logic

* It's the circle of life...

* More shenanigans for n_seq > 1

* Nail in the coffin?

* Fix Windows build

* Eh, one fails on Windows, the other fails on Mac... just use general capture.

* quant : cleanup

* model : cleanup

* qwen3 : cleanup

* cont : cleanup

* cont : cleanup

* ggml : revert change

* qwen3 : cleanup

* cont : cleanup

* Readd cmath

* qwen3 : fix typo

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Usual suspects

* fix my bad suggestion

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Piotr Wilkin (ilintar) 1 month ago
parent
commit
ff55414c42

+ 30 - 0
convert_hf_to_gguf.py

@@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3NextForCausalLM")
+class Qwen3NextModel(Qwen2MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3NEXT
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
+        self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
+        self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
+        self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
+        self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("mtp"):
+            return [] # ignore MTP layers for now
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif "conv1d" in name:
+            data_torch = data_torch.squeeze()
+        elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
+            data_torch = data_torch + 1
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("RND1")
 class RND1Model(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.RND1
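
As a self-contained illustration of the tensor rewrites that Qwen3NextModel.modify_tensors performs above (the helper name and return convention here are made up for the sketch; the real method additionally defers to the Qwen2MoE base-class mapping):

    import torch

    def remap_qwen3next_tensor(name: str, t: torch.Tensor):
        # Illustrative mirror of the rules in Qwen3NextModel.modify_tensors.
        if name.startswith("mtp"):
            return None                                   # MTP layers are skipped for now
        if name.endswith(".A_log"):
            t = -torch.exp(t)                             # store A = -exp(A_log)
        elif name.endswith(".dt_bias"):
            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
        elif "conv1d" in name:
            t = t.squeeze()                               # drop the singleton conv channel dim
        elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
            t = t + 1                                     # shift norm weights to the standard RMSNorm convention
        return name, t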

+ 7 - 1
examples/model-conversion/scripts/causal/run-converted-model.sh

@@ -4,6 +4,11 @@ set -e
 
 # First try command line argument, then environment variable, then file
 CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
+MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
+
+if [ -z "$MODEL_TESTING_PROMPT"]; then
+    MODEL_TESTING_PROMPT="Hello, my name is"
+fi
 
 # Final check if we have a model path
 if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
 fi
 
 echo $CONVERTED_MODEL
+echo $MODEL_TESTING_PROMPT
 
 cmake --build ../../build --target llama-logits -j8
 
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
+../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"

+ 6 - 2
examples/model-conversion/scripts/causal/run-org-model.py

@@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
 # of using AutoModelForCausalLM.
 print(f"Model class: {model.__class__.__name__}")
 
-prompt = "Hello, my name is"
-input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+device = next(model.parameters()).device
+if os.getenv("MODEL_TESTING_PROMPT"):
+    prompt = os.getenv("MODEL_TESTING_PROMPT")
+else:
+    prompt = "Hello, my name is"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
 print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")

+ 2 - 1
ggml/src/ggml-cpu/ops.cpp

@@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
             }
 
             const float diag = A_batch[i00 * n + i00];
-            GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
+
             X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
         }
     }
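
The surrounding loop performs a row-by-row forward substitution for a lower-triangular solve, which is why a zero diagonal must be rejected. A minimal NumPy analogue of that inner computation (an illustration only, not the ggml kernel):

    import numpy as np

    def solve_tri_lower(A, B):
        # Solve A @ X = B for X, where A is (n, n) lower-triangular and B is (n, k).
        n = A.shape[0]
        X = np.zeros_like(B, dtype=np.float64)
        for i in range(n):
            s = A[i, :i] @ X[:i]                      # contribution of rows already solved
            assert A[i, i] != 0.0, "Zero diagonal in triangular matrix"
            X[i] = (B[i] - s) / A[i, i]
        return X                                      # np.allclose(A @ X, B) should hold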

+ 33 - 0
gguf-py/gguf/constants.py

@@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum):
     QWEN2VL          = auto()
     QWEN3            = auto()
     QWEN3MOE         = auto()
+    QWEN3NEXT        = auto()
     QWEN3VL          = auto()
     QWEN3VLMOE       = auto()
     PHI2             = auto()
@@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
     SSM_D                = auto()
     SSM_NORM             = auto()
     SSM_OUT              = auto()
+    SSM_BETA_ALPHA       = auto() # qwen3next
     TIME_MIX_W0          = auto()
     TIME_MIX_W1          = auto()
     TIME_MIX_W2          = auto()
@@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.QWEN2VL:          "qwen2vl",
     MODEL_ARCH.QWEN3:            "qwen3",
     MODEL_ARCH.QWEN3MOE:         "qwen3moe",
+    MODEL_ARCH.QWEN3NEXT:        "qwen3next",
     MODEL_ARCH.QWEN3VL:          "qwen3vl",
     MODEL_ARCH.QWEN3VLMOE:       "qwen3vlmoe",
     MODEL_ARCH.PHI2:             "phi2",
@@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_D:                     "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_NORM:                  "blk.{bid}.ssm_norm",
     MODEL_TENSOR.SSM_OUT:                   "blk.{bid}.ssm_out",
+    MODEL_TENSOR.SSM_BETA_ALPHA:            "blk.{bid}.ssm_ba",
     MODEL_TENSOR.TIME_MIX_W0:               "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1:               "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2:               "blk.{bid}.time_mix_w2",
@@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.QWEN3NEXT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_BETA_ALPHA,
+        MODEL_TENSOR.SSM_OUT
+    ],
     MODEL_ARCH.QWEN3VL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

+ 16 - 6
gguf-py/gguf/tensor_mapping.py

@@ -672,10 +672,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",               # mamba-hf
-            "backbone.layers.{bid}.mixer.in_proj",      # mamba
-            "model.layers.{bid}.mamba.in_proj",         # jamba falcon-h1 granite-hybrid
-            "model.layers.layers.{bid}.mixer.in_proj",  # plamo2
+            "model.layers.{bid}.in_proj",                   # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj",          # mamba
+            "model.layers.{bid}.mamba.in_proj",             # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.in_proj",      # plamo2
+            "model.layers.{bid}.linear_attn.in_proj_qkvz",  # qwen3next
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
@@ -683,6 +684,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.conv1d",      # mamba
             "model.layers.{bid}.mamba.conv1d",         # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.conv1d",  # plamo2
+            "model.layers.{bid}.linear_attn.conv1d",   # qwen3next
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -697,6 +699,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.dt_proj",      # mamba
             "model.layers.{bid}.mamba.dt_proj",         # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.dt_proj",  # plamo2
+            "model.layers.{bid}.linear_attn.dt_proj",   # qwen3next
         ),
 
         MODEL_TENSOR.SSM_DT_NORM: (
@@ -709,6 +712,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.A_log",      # mamba
             "model.layers.{bid}.mamba.A_log",         # jamba falcon-h1 granite-hybrid
             "model.layers.layers.{bid}.mixer.A_log",  # plamo2
+            "model.layers.{bid}.linear_attn.A_log",   # qwen3next
         ),
 
         MODEL_TENSOR.SSM_B_NORM: (
@@ -731,17 +735,23 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_NORM: (
-            "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
-            "backbone.layers.{bid}.mixer.norm",  # mamba2
+            "model.layers.{bid}.mamba.norm",        # falcon-h1 granite-hybrid
+            "model.layers.{bid}.linear_attn.norm",  # qwen3next
+            "backbone.layers.{bid}.mixer.norm",     # mamba2
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",               # mamba-hf
             "backbone.layers.{bid}.mixer.out_proj",      # mamba
             "model.layers.{bid}.mamba.out_proj",         # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.linear_attn.out_proj",   # qwen3next
             "model.layers.layers.{bid}.mixer.out_proj",  # plamo2
         ),
 
+        MODEL_TENSOR.SSM_BETA_ALPHA: (
+            "model.layers.{bid}.linear_attn.in_proj_ba",  # qwen3next
+        ),
+
         MODEL_TENSOR.TIME_MIX_W0: (
             "model.layers.{bid}.attention.w0",            # rwkv7
         ),

+ 1 - 0
src/CMakeLists.txt

@@ -114,6 +114,7 @@ add_library(llama
             models/qwen3vl.cpp
             models/qwen3vl-moe.cpp
             models/qwen3moe.cpp
+            models/qwen3next.cpp
             models/refact.cpp
             models/rnd1.cpp
             models/rwkv6-base.cpp

+ 35 - 0
src/llama-arch.cpp

@@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2VL,          "qwen2vl"          },
     { LLM_ARCH_QWEN3,            "qwen3"            },
     { LLM_ARCH_QWEN3MOE,         "qwen3moe"         },
+    { LLM_ARCH_QWEN3NEXT,        "qwen3next"        },
     { LLM_ARCH_QWEN3VL,          "qwen3vl"          },
     { LLM_ARCH_QWEN3VLMOE,       "qwen3vlmoe"       },
     { LLM_ARCH_PHI2,             "phi2"             },
@@ -829,6 +830,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_QWEN3NEXT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,     "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_SSM_A,              "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_CONV1D,         "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT,             "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_BETA_ALPHA,     "blk.%d.ssm_ba" },
+            { LLM_TENSOR_SSM_IN,             "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_NORM,           "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT,            "blk.%d.ssm_out" },
+        },
+    },
     {
         LLM_ARCH_QWEN3VL,
         {
@@ -2556,6 +2589,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_X,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SSM_DT,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SSM_OUT,                    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SSM_BETA_ALPHA,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_A1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -2754,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_NEMOTRON_H:
+        case LLM_ARCH_QWEN3NEXT:
             return true;
         default:
             return false;

+ 2 - 0
src/llama-arch.h

@@ -36,6 +36,7 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3NEXT,
     LLM_ARCH_QWEN3VL,
     LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
@@ -381,6 +382,7 @@ enum llm_tensor {
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_NORM,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_SSM_BETA_ALPHA,      // qwen3next
     LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,

+ 4 - 0
src/llama-context.cpp

@@ -1,5 +1,6 @@
 #include "llama-context.h"
 
+#include "llama-arch.h"
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-io.h"
@@ -1386,6 +1387,9 @@ void llama_context::output_reorder() {
 //
 
 uint32_t llama_context::graph_max_nodes() const {
+    if (model.arch == LLM_ARCH_QWEN3NEXT) {
+        return std::max<uint32_t>(8192u, 32u*model.n_tensors());
+    }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
 

+ 1 - 1
src/llama-hparams.h

@@ -6,7 +6,7 @@
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 384  // Kimi-K2
+#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
 
 enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_NONE           = 0,

+ 98 - 2
src/llama-model.cpp

@@ -2,7 +2,6 @@
 
 #include "llama-impl.h"
 #include "llama-mmap.h"
-#include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
 
@@ -2225,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN3NEXT:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+
+                // Load linear attention (gated delta net) parameters
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);
+
+                // Mark recurrent layers (linear attention layers)
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
+                }
+
+                switch (hparams.n_layer) {
+                    case 80: type = LLM_TYPE_80B_A3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -6415,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_QWEN3NEXT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+                    }
+
+                    const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                    // Calculate dimensions from hyperparameters
+                    const int64_t head_k_dim = hparams.ssm_d_state;
+                    const int64_t head_v_dim = hparams.ssm_d_state;
+                    const int64_t n_k_heads  = hparams.ssm_n_group;
+                    const int64_t n_v_heads  = hparams.ssm_dt_rank;
+                    const int64_t key_dim    = head_k_dim * n_k_heads;
+                    const int64_t value_dim  = head_v_dim * n_v_heads;
+                    const int64_t conv_dim   = key_dim * 2 + value_dim;
+
+                    // Calculate projection sizes
+                    const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
+                    const int64_t ba_dim   = n_v_heads * 2;
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM,      "weight", i), { n_embd }, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+                        if (!hparams.is_recurrent(i)) {
+                            // Attention layers
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), { n_embd, n_embd_k_gqa }, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), { n_embd, n_embd_v_gqa }, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+                            // Q/K normalization for attention layers
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+                        } else {
+                            // Linear attention (gated delta net) specific tensors
+                            // Create tensors with calculated dimensions
+                            layer.ssm_in         = create_tensor(tn(LLM_TENSOR_SSM_IN,         "weight", i), { n_embd, qkvz_dim }, 0);
+                            layer.ssm_conv1d     = create_tensor(tn(LLM_TENSOR_SSM_CONV1D,     "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
+                            layer.ssm_dt         = create_tensor(tn(LLM_TENSOR_SSM_DT,         "bias",   i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_a          = create_tensor(tn(LLM_TENSOR_SSM_A,                    i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
+                            layer.ssm_norm       = create_tensor(tn(LLM_TENSOR_SSM_NORM,       "weight", i), { head_v_dim }, 0);
+                            layer.ssm_out        = create_tensor(tn(LLM_TENSOR_SSM_OUT,        "weight", i), { value_dim, n_embd }, 0);
+                        }
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), { n_embd, n_expert }, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+
+                        // Shared experts
+                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
+                        layer.ffn_gate_shexp     = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP,     "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
+                        layer.ffn_up_shexp       = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,       "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
+                        layer.ffn_down_shexp     = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP,     "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -6685,6 +6775,7 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_FALCON_H1 ||
         arch == LLM_ARCH_PLAMO2 ||
         arch == LLM_ARCH_GRANITE_HYBRID ||
+        arch == LLM_ARCH_QWEN3NEXT ||
         arch == LLM_ARCH_NEMOTRON_H) {
         LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
@@ -7426,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
         case LLM_ARCH_PANGU_EMBED:
             {
                 llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
-            }break;
+            } break;
+        case LLM_ARCH_QWEN3NEXT:
+            {
+                llm = std::make_unique<llm_build_qwen3next>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -7653,6 +7748,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_COGVLM:
         case LLM_ARCH_PANGU_EMBED:
         case LLM_ARCH_AFMOE:
+        case LLM_ARCH_QWEN3NEXT:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
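
To make the dimension bookkeeping in the QWEN3NEXT tensor-creation branch concrete, a short Python calculation with hypothetical hyperparameter values (chosen for illustration only, not taken from any particular checkpoint):

    # hypothetical values, for illustration only
    ssm_d_state = 128   # head_k_dim == head_v_dim
    ssm_n_group = 16    # number of key heads
    ssm_dt_rank = 32    # number of value heads

    key_dim   = ssm_d_state * ssm_n_group       # 2048
    value_dim = ssm_d_state * ssm_dt_rank       # 4096
    conv_dim  = key_dim * 2 + value_dim         # 8192  -> rows of ssm_conv1d
    qkvz_dim  = key_dim * 2 + value_dim * 2     # 12288 -> output size of ssm_in
    ba_dim    = ssm_dt_rank * 2                 # 64    -> output size of ssm_ba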

+ 4 - 0
src/llama-model.h

@@ -113,6 +113,7 @@ enum llm_type {
     LLM_TYPE_16B_A1B,
     LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
+    LLM_TYPE_80B_A3B, // Qwen3 Next
     LLM_TYPE_100B_A6B,
     LLM_TYPE_106B_A12B, // GLM-4.5-Air
     LLM_TYPE_230B_A10B, // Minimax M2
@@ -309,6 +310,9 @@ struct llama_layer {
     struct ggml_tensor * ssm_conv1d_b = nullptr;
     struct ggml_tensor * ssm_dt_b     = nullptr;
 
+    // qwen3next
+    struct ggml_tensor * ssm_beta_alpha = nullptr;
+
     // rwkv
     struct ggml_tensor * time_mix_w1         = nullptr;
     struct ggml_tensor * time_mix_w2         = nullptr;

+ 13 - 5
src/llama-quant.cpp

@@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             }
             LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
             continue;
-        } else if (remapped_name != it.first) {
+        }
+
+        if (remapped_name != it.first) {
             ggml_set_name(it.second.tensor, remapped_name.c_str());
             LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
         }
@@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
-        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
+        int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
         if (llama_model_has_encoder(&model)) {
-            // now n_attn_layer is the number of attention layers in the encoder
+            // now n_layer_attn is the number of attention layers in the encoder
             // for each decoder block, there are 2 attention layers
-            n_attn_layer += 2 * model.hparams.dec_n_layer;
+            n_layer_attn += 2 * model.hparams.dec_n_layer;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
+
+        // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
+        const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
+
+        LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
+
+        GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
     }
 
     size_t total_size_org = 0;
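
A worked example of the updated assertion, with hypothetical numbers for an 80-layer Qwen3-Next-style model (full attention on every 4th layer, no pruned heads, and assuming every layer reports a non-zero KV-head count in the hparams):

    n_layer      = 80
    n_layer_recr = sum((i + 1) % 4 != 0 for i in range(n_layer))  # 60 linear-attention layers
    n_layer_attn = n_layer                                        # assumption: all layers have n_head_kv > 0
    pruned       = 0
    # only full-attention layers carry an attn_v.weight tensor:
    n_attention_wv = n_layer_attn - pruned - n_layer_recr         # == 20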

+ 51 - 1
src/models/models.h

@@ -2,8 +2,9 @@
 
 #include "../llama-model.h"
 #include "../llama-graph.h"
-#include "../llama-memory-recurrent.h"
 
+// TODO: remove in follow-up PR - move to .cpp files
+#include "../llama-memory-recurrent.h"
 #include <cmath>
 
 struct llm_graph_context_mamba : public llm_graph_context {
@@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context {
 struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
+struct llm_build_qwen3next : public llm_graph_context_mamba {
+    llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                        int   il);
 
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_delta_net_recurrent(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                        int   il);
+
+    ggml_tensor * build_delta_net_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                        int   il);
+
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    const llama_model & model;
+};
 
 struct llm_build_qwen : public llm_graph_context {
     llm_build_qwen(const llama_model & model, const llm_graph_params & params);
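
The build_delta_net_recurrent/build_delta_net_chunking helpers declared above compute the gated delta rule over whole sequences or fixed-size chunks. As a reading aid, a minimal single-head, single-sequence NumPy sketch of the underlying recurrence (omitting the q/k L2 normalization, the 1/sqrt(d) query scaling, and the sigmoid on beta that the ggml graph applies beforehand):

    import numpy as np

    def gated_delta_rule(q, k, v, g, beta, S=None):
        # q, k: [T, d_k]; v: [T, d_v]; g: [T] log-decay; beta: [T] write strength
        T, d_k = q.shape
        d_v    = v.shape[1]
        S      = np.zeros((d_k, d_v)) if S is None else S         # recurrent state
        out    = np.zeros((T, d_v))
        for t in range(T):
            S      = S * np.exp(g[t])                             # gated decay of the state
            pred   = k[t] @ S                                     # what the state currently predicts for v_t
            S      = S + np.outer(k[t], beta[t] * (v[t] - pred))  # delta-rule correction
            out[t] = q[t] @ S                                     # read out with the query
        return out, S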

+ 1042 - 0
src/models/qwen3next.cpp

@@ -0,0 +1,1042 @@
+#include "ggml.h"
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context_mamba(params), model(model) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    ggml_tensor * causal_mask =
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
+                    GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
+
+    ggml_build_forward_expand(gf, causal_mask);
+    ggml_build_forward_expand(gf, identity);
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // Determine layer type and build appropriate attention mechanism
+        if (hparams.is_recurrent(il)) {
+            // Linear attention layer (gated delta net)
+            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il);
+        } else {
+            // Full attention layer
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual connection
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        // Save the tensor before post-attention norm for residual connection
+        ggml_tensor * ffn_residual = cur;
+
+        // Post-attention norm
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        // FFN layer (MoE or dense) - without residual connection
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_moe", il);
+
+        // Input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final norm
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // LM head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        int           il) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    // TODO: can this ever be false?
+    const bool use_qk_l2norm = true;
+
+    if (use_qk_l2norm) {
+        const float eps_norm = hparams.f_norm_rms_eps;
+
+        q = ggml_l2_norm(ctx0, q, eps_norm);
+        k = ggml_l2_norm(ctx0, k, eps_norm);
+    }
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+
+    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(g, "g_perm", il);
+    cb(state, "state_in", il);
+
+    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+    // Do padding
+    const int64_t chunk_size = CHUNK_SIZE;
+
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(g, "g_pad", il);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
+
+    ggml_tensor * chunked_mask =
+        ggml_view_4d(ctx0, causal_mask, chunk_size,
+                chunk_size,         causal_mask->ne[2], causal_mask->ne[3],
+                causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0);
+
+    ggml_tensor * chunked_diag_mask =
+        ggml_view_4d(ctx0, causal_diag_mask, chunk_size,
+                chunk_size,              causal_diag_mask->ne[2], causal_diag_mask->ne[3],
+                causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0);
+
+    ggml_tensor * chunked_identity =
+        ggml_view_4d(ctx0, identity, chunk_size,
+            chunk_size,      identity->ne[2], identity->ne[3],
+            identity->nb[1], identity->nb[2], identity->nb[3], 0);
+
+    q      = ggml_cont_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k      = ggml_cont_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    v      = ggml_cont_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
+    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
+
+    g    = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
+    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
+
+    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
+
+    cb(g_cumsum, "g_cumsum", il);
+
+    ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
+
+    ggml_tensor * gcs_j_broadcast =
+        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
+
+    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
+
+    cb(decay_mask, "decay_mask", il);
+
+    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+
+    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
+
+    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
+    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask));
+
+    cb(attn, "attn_pre_solve", il);
+
+    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask);
+    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower);
+
+    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn                     = ggml_mul(ctx0, lin_solve, chunked_mask);
+    attn                     = ggml_add(ctx0, attn, chunked_identity);
+
+    cb(attn, "attn_solved", il);
+
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
+
+    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
+    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
+
+    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
+
+    cb(kbeta_gexp, "kbeta_gexp", il);
+
+    ggml_tensor * k_cumdecay =
+        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
+
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    ggml_tensor * core_attn_out = nullptr;
+    ggml_tensor * new_state = ggml_dup(ctx0, state);
+
+    cb(new_state, "new_state", il);
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        auto chunkify = [=](ggml_tensor * t) {
+            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+
+        auto chunkify_g = [=](ggml_tensor * t) {
+            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+
+        ggml_tensor * k_chunk = chunkify(k);
+        ggml_tensor * q_chunk = chunkify(q);
+        ggml_tensor * v_chunk = chunkify(v);
+
+        ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
+        ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
+
+        ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
+        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+
+        ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
+
+        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
+        attn = ggml_mul(ctx0, attn, decay_mask_chunk);
+        attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask));
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
+
+        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+
+        // v_new = v_i - v_prime
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+
+        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
+
+        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
+
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+
+        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+
+        // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+        // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+        // key_gdiff = key * g_diff.unsqueeze(-1)
+        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+
+        ggml_tensor * g_cum_last =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
+                                        g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
+                                        g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
+
+        ggml_tensor * gexp_last =
+            ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
+
+        ggml_tensor * g_cum_last_3d =
+            ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
+
+        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
+
+        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
+
+        ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
+
+        ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
+                                        ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
+                                                        g_diff_exp->ne[2] * g_diff_exp->ne[3]));
+
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
+
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
+    cb(output_tokens, "output_tokens", il);
+
+    // flatten output
+    ggml_tensor * flat_output =
+        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
+
+    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
+
+    return ggml_concat(ctx0, flat_output, flat_state, 0);
+}
+
+ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        int           il) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    // TODO: can this ever be false?
+    const bool use_qk_l2norm = true;
+
+    if (use_qk_l2norm) {
+        const float eps_norm = hparams.f_norm_rms_eps;
+
+        q = ggml_l2_norm(ctx0, q, eps_norm);
+        k = ggml_l2_norm(ctx0, k, eps_norm);
+    }
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+
+    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(g, "g_perm", il);
+    cb(state, "state_in", il);
+
+    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
+
+    cb(k_beta, "k_beta", il);
+    cb(v_beta, "v_beta", il);
+    cb(g_cumsum, "g_cumsum", il);
+
+    ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs);  // [n_tokens, 1, H_v, n_seqs]
+    ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs);  // [1, n_tokens, H_v, n_seqs]
+
+    // Broadcast gcs_j to [n_tokens, n_tokens, H_v, n_seqs]; gcs_i needs no explicit
+    // repeat because ggml_sub broadcasts it automatically.
+    ggml_tensor * gcs_j_broadcast =
+        ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs);
+
+    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
+
+    // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
+    decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
+    // Apply exponential to get the decay mask values
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    // Apply lower triangular mask again to ensure only lower triangular values remain
+    decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
+
+    cb(decay_mask, "decay_mask", il);
+
+    // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
+    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
+
+    cb(kmulkbeta, "kmulkbeta", il);
+
+    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
+    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
+
+    cb(attn, "attn_pre_rec", il);
+
+    // for i in range(1, chunk_size):
+    //          row = attn[..., i, :i].clone()
+    //          sub = attn[..., :i, :i].clone()
+    //          attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    //
+    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
+    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
+    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+
+    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
+    attn                     = ggml_add(ctx0, attn, identity);
+
+    // value = attn @ v_beta
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
+
+    cb(v, "value_beta", il);
+
+    // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
+    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
+
+    cb(gexp, "g_cum_exp", il);
+
+    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
+
+    cb(kbeta_gexp, "kbeta_gexp", il);
+
+    ggml_tensor * k_cumdecay =
+        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
+
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+    attn = ggml_mul_mat(ctx0, k, q);
+    attn = ggml_mul(ctx0, attn, decay_mask);
+    attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask));
+
+    cb(attn, "attn_decay_key", il);
+
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+
+    // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+    ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay);
+
+    cb(v_prime, "v_prime", il);
+
+    // v_new = v_i - v_prime
+    ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime);
+
+    ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+
+    cb(v_new, "v_new", il);
+
+    // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+    ggml_tensor * q_g_exp    = ggml_mul(ctx0, q, gexp);
+    ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
+
+    cb(attn_inter, "attn_inter", il);
+
+    // core_attn_out[:, :, i] = attn_inter + attn @ v_new
+    ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
+
+    cb(v_attn, "v_attn", il);
+
+    ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn);
+
+    cb(core_attn_out, "core_attn_out", il);
+
+    // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+    // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+    // key_gdiff = key * g_diff.unsqueeze(-1)
+    // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+    // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+
+    ggml_tensor * g_cum_last =
+        ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3],
+                                    g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3],
+                                    g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));
+
+    cb(g_cum_last, "g_cum_last", il);
+
+    ggml_tensor * gexp_last =
+        ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
+
+    cb(gexp_last, "gexp_last", il);
+
+    ggml_tensor * g_cum_last_3d =
+        ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
+
+    cb(g_cum_last_3d, "g_cum_last_3d", il);
+
+    ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
+
+    cb(g_cumsum_3d, "g_cumsum_3d", il);
+
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
+
+    cb(g_diff, "g_diff", il);
+
+    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
+
+    cb(g_diff_exp, "g_diff_exp", il);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k,
+                                    ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
+                                                    g_diff_exp->ne[2] * g_diff_exp->ne[3]));
+
+    cb(key_gdiff, "key_gdiff", il);
+
+    ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
+
+    cb(kgdmulvnew, "kgdmulvnew", il);
+
+    state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew);
+
+    cb(state, "new_state", il);
+
+    // flatten output
+    ggml_tensor * flat_output =
+        ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
+
+    ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
+
+    return ggml_concat(ctx0, flat_output, flat_state, 0);
+}
+
+ggml_tensor * llm_build_qwen3next::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen3next::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
+
+    // Qwen3Next uses a single Q projection that outputs query + gate
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
+    cb(Qcur_full, "Qcur_full", il);
+
+    Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
+
+    // Split Q projection into query and gate
+    // The split should be along dimension 0 (the feature dimension)
+    ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
+                                             Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+    ggml_tensor * gate =
+        ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
+                     Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
+    cb(Qcur, "Qcur", il);
+    cb(gate, "gate", il);
+
+    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
+    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    // Apply Q normalization
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+    cb(Vcur, "Vcur", il);
+
+    // Apply K normalization
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    // Reshape gate to [n_embd_head * n_head, n_tokens] for the sigmoid gating (flatten the heads)
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    // Apply RoPE
+    Qcur = ggml_rope_ext(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+
+    Kcur = ggml_rope_ext(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, rope_type, n_ctx_orig, freq_base,
+            freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    // Attention computation
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
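+    // gate the attention output with sigmoid(gate) from the joint QG projection before the output projection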
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        ggml_tensor *        causal_mask,
+        ggml_tensor *        identity,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // Input projections
+    ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
+    cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+    ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
+    cb(mixed_ba, "linear_attn_mixed_ba", il);
+
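+    // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]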
+    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+    // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
+    int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
+    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+    // Split mixed_ba into b and a (beta and alpha parameters)
+    int64_t split_sizes_ba[2] = {
+        num_v_heads / num_k_heads,  // beta size
+        num_v_heads / num_k_heads   // alpha size
+    };
+
+    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
+    cb(b, "b", il);
+
+    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
+                                   split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
+    cb(a, "a", il);
+
+    // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
+    ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
+
+    GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
+    cb(gate, "gate", il);
+
+    // Split mixed_qkvz into query, key, value, z
+    int64_t split_sizes_qkvz[4] = {
+        head_k_dim,                              // query size
+        head_k_dim,                              // key size
+        head_v_dim * num_v_heads / num_k_heads,  // value size
+        head_v_dim * num_v_heads / num_k_heads   // z size
+    };
+
+    ggml_tensor * query =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
+    cb(query, "q", il);
+
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+                                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                     split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
+    cb(key, "k", il);
+
+    ggml_tensor * value =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
+    cb(value, "v", il);
+
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                   (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
+    cb(z, "z", il);
+
+    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) ==
+                ggml_nelements(mixed_qkvz));
+
+    // After creating query, key, and value, reshape each to flatten the head dimensions
+    // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    cb(query_flat, "query_flat", il);
+
+    // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    cb(key_flat, "key_flat", il);
+
+    // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
+    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(value_flat, "value_flat", il);
+
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
+    ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
+    qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+    cb(qkv_mixed, "qkv_mixed", il);
+
+    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+
+    // Calculate convolution kernel size
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+    cb(conv_states_all, "conv_states_updated", il);
+
+    // Apply SSM convolution
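+    //   (a depthwise causal convolution over [cached conv states | new tokens] along the time dimension)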
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
+    cb(conv_output_proper, "conv_output_pre_silu", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix =
+        ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
+    cb(conv_qkv_mix, "conv_qkv_mix", il);
+
+    // Extract the convolved Q, K, V from conv_output
+    ggml_tensor * q_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
+    cb(q_conv, "q_conv", il);
+    ggml_tensor * k_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
+                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+    cb(k_conv, "k_conv", il);
+    ggml_tensor * v_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
+                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+    cb(v_conv, "v_conv", il);
+
+    // Unsqueeze them
+    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
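+    // reinsert a singleton dimension so beta is [num_v_heads, 1, n_seq_tokens, n_seqs] for the delta-net kernels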
+    beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
+    cb(state, "state_predelta", il);
+
+    // if the number of key heads and value heads differs, repeat-interleave Q and K so all tensors have num_v_heads heads
+    if (num_k_heads != num_v_heads) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        int64_t repeat_factor = num_v_heads / num_k_heads;
+
+        // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back
+        ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
+        ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
+
+        // Repeat along the third dimension (the new dimension with size 1)
+        ggml_tensor * q_repeated =
+            ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
+        ggml_tensor * k_repeated =
+            ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
+
+        // Reshape back to merge the head and repeat dimensions
+        // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
+        // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
+        q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_seq_tokens
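+    //   sequences longer than CHUNK_SIZE tokens use the chunked formulation, shorter ones the step-by-step recurrence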
+    ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
+        build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) :
+        build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
+    cb(attn_out, "attn_out", il);
+
+    // The tensors were concatenated 1d, so we need to extract them 1d as well
+    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
+    ggml_tensor * attn_out_1d      = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
+    cb(attn_out_1d, "attn_out_1d", il);
+
+    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    cb(attn_out_final, "attn_out_reshaped", il);
+
+    // Extract the state part (second part of the concatenated tensor)
+    // State starts after the output_flat_size attention-output elements in the flattened (1D) tensor
+    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
+
+    ggml_tensor * state_1d =
+        ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
+    cb(state_1d, "state_1d", il);
+
+    // Update the recurrent states
+    ggml_build_forward_expand(gf,
+                              ggml_cpy(ctx0, state_1d,
+                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
+
+    // Reshape both attn_out_final and z to 2D tensors for normalization
+    // attn_out_final: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim, num_v_heads * n_tokens * n_seqs]
+    ggml_tensor * attn_out_2d_final =
+        ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+
+    // z: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim, num_v_heads * n_tokens * n_seqs]
+    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+
+    // Apply gated normalization: self.norm(core_attn_out, z)
+    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+
+    // Final reshape: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    // Output projection
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+    cb(cur, "linear_attn_out", il);
+
+    // Reshape back to original dimensions
+    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Check if this is an MoE layer
+    if (model.layers[il].ffn_gate_inp != nullptr) {
+        // MoE branch
+        ggml_tensor * moe_out =
+            build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                nullptr,
+                n_expert, n_expert_used, LLM_FFN_SILU,
+                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+        cb(moe_out, "ffn_moe_out", il);
+
+        // Add shared experts if present - following Qwen3Next reference implementation
+        if (model.layers[il].ffn_up_shexp != nullptr) {
+            ggml_tensor * ffn_shexp =
+                build_ffn(cur,
+                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_gate_shexp, NULL, NULL,
+                    model.layers[il].ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(ffn_shexp, "ffn_shexp", il);
+
+            // Apply shared expert gating as in the reference implementation
+            // The shared expert has its own gate that is sigmoided
+            // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
+            ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+            cb(shared_gate, "shared_expert_gate", il);
+
+            // Apply sigmoid to the gate
+            shared_gate = ggml_sigmoid(ctx0, shared_gate);
+            cb(shared_gate, "shared_expert_gate_sigmoid", il);
+
+            // The gate needs to be broadcast to match the dimensions of ffn_shexp
+            // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
+            // We need to repeat the gate along the feature dimension
+            shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
+            cb(shared_gate, "shared_expert_gate_broadcast", il);
+
+            // Apply the gate to the shared expert output
+            ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+            cb(ffn_shexp, "ffn_shexp_gated", il);
+
+            cur = ggml_add(ctx0, moe_out, ffn_shexp);
+            cb(cur, "ffn_out", il);
+        } else {
+            cur = moe_out;
+        }
+    } else {
+        // Dense FFN branch (not expected to be used by current Qwen3Next configurations)
+        cur = build_ffn(cur,
+            model.layers[il].ffn_up, NULL, NULL,
+            model.layers[il].ffn_gate, NULL, NULL,
+            model.layers[il].ffn_down, NULL, NULL,
+            NULL,
+            LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+    }
+    return cur;
+}