
Override SSM_A op for Qwen3 Next to reduce splits (#17587)

* Override SSM_A op for Qwen3 Next to reduce splits

* New tensor mapping SSM_A_NOSCAN for SSM_A when it is used outside of the OP_SSM_SCAN context.

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Piotr Wilkin (ilintar), 1 month ago
Parent commit 746f9ee889
3 files changed, 4 additions and 2 deletions
  1. src/llama-arch.cpp (+2 −1)
  2. src/llama-arch.h (+1 −0)
  3. src/llama-model.cpp (+1 −1)

+ 2 - 1
src/llama-arch.cpp

@@ -855,7 +855,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
             { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
-            { LLM_TENSOR_SSM_A,              "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_A_NOSCAN,       "blk.%d.ssm_a" },
             { LLM_TENSOR_SSM_CONV1D,         "blk.%d.ssm_conv1d" },
             { LLM_TENSOR_SSM_DT,             "blk.%d.ssm_dt" },
             { LLM_TENSOR_SSM_BETA_ALPHA,     "blk.%d.ssm_ba" },
@@ -2639,6 +2639,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_ACT,                    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
     {LLM_TENSOR_SSM_CONV1D,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SSM_A,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
+    {LLM_TENSOR_SSM_A_NOSCAN,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // a version of SSM_A used for MUL instead of SSM_SCAN
     {LLM_TENSOR_SSM_DT_NORM,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_B_NORM,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_C_NORM,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
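
For context on why remapping the tensor reduces splits: as I understand it, the op recorded in LLM_TENSOR_INFOS is what the loader uses to probe whether a backend can host a given weight, so a backend without an SSM_SCAN kernel would keep ssm_a off-device and force extra graph splits, even though Qwen3 Next only consumes this tensor through a plain multiply (hence GGML_OP_MUL). The sketch below is a minimal standalone illustration of that buffer-selection idea, not llama.cpp code; backend_caps and pick_buffer are hypothetical names invented for the example.

// Standalone sketch (hypothetical names, not llama.cpp/ggml APIs): the loader
// checks whether a backend supports the op a weight feeds into, and falls back
// to a host buffer (later causing a graph split) when it does not.
#include <cstdio>
#include <set>
#include <string>

enum class op { MUL, SSM_SCAN };

struct backend_caps {
    std::string  name;
    std::set<op> supported_ops;   // ops this (hypothetical) backend can run
};

// Place a weight on the device only if the backend can run its consumer op;
// otherwise keep it on the host, which splits the graph around that node.
std::string pick_buffer(const backend_caps & dev, op consumer) {
    return dev.supported_ops.count(consumer) ? dev.name : "CPU (split)";
}

int main() {
    // A device that implements MUL but has no SSM_SCAN kernel.
    backend_caps gpu{"GPU", {op::MUL}};

    // Old mapping: ssm_a declared as an SSM_SCAN input -> pushed to the host.
    std::printf("ssm_a as SSM_SCAN -> %s\n", pick_buffer(gpu, op::SSM_SCAN).c_str());

    // New mapping (SSM_A_NOSCAN -> GGML_OP_MUL): the weight can stay on device,
    // matching how Qwen3 Next actually uses it.
    std::printf("ssm_a as MUL      -> %s\n", pick_buffer(gpu, op::MUL).c_str());
}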

+ 1 - 0
src/llama-arch.h

@@ -379,6 +379,7 @@ enum llm_tensor {
     LLM_TENSOR_SSM_DT,
     LLM_TENSOR_SSM_DT_NORM,
     LLM_TENSOR_SSM_A,
+    LLM_TENSOR_SSM_A_NOSCAN,        // qwen3next special case with MUL instead of SSM_SCAN
     LLM_TENSOR_SSM_B_NORM,
     LLM_TENSOR_SSM_C_NORM,
     LLM_TENSOR_SSM_D,

+ 1 - 1
src/llama-model.cpp

@@ -6526,7 +6526,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.ssm_in         = create_tensor(tn(LLM_TENSOR_SSM_IN,         "weight", i), { n_embd, qkvz_dim }, 0);
                             layer.ssm_conv1d     = create_tensor(tn(LLM_TENSOR_SSM_CONV1D,     "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
                             layer.ssm_dt         = create_tensor(tn(LLM_TENSOR_SSM_DT,         "bias",   i), { hparams.ssm_dt_rank }, 0);
-                            layer.ssm_a          = create_tensor(tn(LLM_TENSOR_SSM_A,                    i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_a          = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN,             i), { hparams.ssm_dt_rank }, 0);
                             layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
                             layer.ssm_norm       = create_tensor(tn(LLM_TENSOR_SSM_NORM,       "weight", i), { head_v_dim }, 0);
                             layer.ssm_out        = create_tensor(tn(LLM_TENSOR_SSM_OUT,        "weight", i), { value_dim, n_embd }, 0);