Piotr Wilkin il y a 3 mois
Parent
commit
0dd6110fdc

+ 0 - 4
convert_hf_to_gguf.py

@@ -3770,10 +3770,6 @@ class Qwen3NextModel(Qwen3MoeModel):
             name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
             name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
         elif "conv1d" in name:
         elif "conv1d" in name:
             data_torch = data_torch.squeeze()
             data_torch = data_torch.squeeze()
-        elif "q_proj.weight" in name:
-            q_proj, gate = data_torch.chunk(2, dim=0)
-            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), gate)
-            data_torch = q_proj
 
 
         yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
         yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
 
 

+ 0 - 2
src/llama-arch.cpp

@@ -769,7 +769,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
             { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
             { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
             { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_GATE,          "blk.%d.attn_gate" },
             { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
@@ -2246,7 +2245,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_GATE,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

+ 0 - 1
src/llama-arch.h

@@ -381,7 +381,6 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
-    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_SUB_NORM,
     LLM_TENSOR_FFN_SUB_NORM,
     LLM_TENSOR_DEC_ATTN_NORM,
     LLM_TENSOR_DEC_ATTN_NORM,
     LLM_TENSOR_DEC_ATTN_Q,
     LLM_TENSOR_DEC_ATTN_Q,

+ 1 - 5
src/llama-model.cpp

@@ -2524,7 +2524,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
 
                         if (!hparams.is_recurrent(i)) {
                         if (!hparams.is_recurrent(i)) {
                             // Attention layers
                             // Attention layers
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
                             layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
                             layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
                             layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
                             layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
                             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
                             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
@@ -2532,10 +2532,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             // Q/K normalization for attention layers
                             // Q/K normalization for attention layers
                             layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
                             layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
                             layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
                             layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
-
-                            // attn gate
-                            layer.wq_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
-
                         } else {
                         } else {
                             // Linear attention (gated delta net) specific tensors
                             // Linear attention (gated delta net) specific tensors
                             // Create tensors with calculated dimensions
                             // Create tensors with calculated dimensions

+ 0 - 1
src/llama-model.h

@@ -228,7 +228,6 @@ struct llama_layer {
     struct ggml_tensor * wk_enc    = nullptr;
     struct ggml_tensor * wk_enc    = nullptr;
     struct ggml_tensor * wv_enc    = nullptr;
     struct ggml_tensor * wv_enc    = nullptr;
     struct ggml_tensor * wo_enc    = nullptr;
     struct ggml_tensor * wo_enc    = nullptr;
-    struct ggml_tensor * wq_gate   = nullptr;
 
 
     // attention bias
     // attention bias
     struct ggml_tensor * bq   = nullptr;
     struct ggml_tensor * bq   = nullptr;

+ 71 - 19
src/models/llm_build_qwen3next.cpp

@@ -57,20 +57,29 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
             // Full attention layer
             // Full attention layer
             cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
             cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
         }
         }
-        // Post-attention norm
-        cur = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
-        cb(cur, "attn_post_norm", il);
 
 
         if (il == n_layer - 1 && inp_out_ids) {
         if (il == n_layer - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
         }
+        
         // Residual connection
         // Residual connection
         cur = ggml_add(ctx0, cur, inpSA);
         cur = ggml_add(ctx0, cur, inpSA);
         cb(cur, "attn_residual", il);
         cb(cur, "attn_residual", il);
 
 
-        // FFN layer (MoE or dense)
-        cur = build_layer_ffn(cur, model, il);
+        // Save the tensor before post-attention norm for residual connection
+        ggml_tensor * ffn_residual = cur;
+        
+        // Post-attention norm
+        ggml_tensor * attn_post_norm = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        // FFN layer (MoE or dense) - without residual connection
+        cur = build_layer_ffn(attn_post_norm, model, il, false);
+        cb(cur, "ffn_out", il);
+        
+        // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
+        cur = ggml_add(ctx0, cur, ffn_residual);
         cb(cur, "post_moe", il);
         cb(cur, "post_moe", il);
 
 
         // Input for next layer
         // Input for next layer
@@ -111,11 +120,30 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
                                                                    const llama_model &       model,
                                                                    const llama_model &       model,
                                                                    const int64_t             n_embd_head,
                                                                    const int64_t             n_embd_head,
                                                                    const int                 il) {
                                                                    const int                 il) {
-    ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
-
     // compute Q and K and RoPE them
     // compute Q and K and RoPE them
-    struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+    // Qwen3Next uses a single Q projection that outputs query + gate
+    struct ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
+    cb(Qcur_full, "Qcur_full", il);
+    Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
+    // Split Q projection into query and gate
+    // The split should be along dimension 0 (the feature dimension)
+    struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+    struct ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 
+        n_embd_head * ggml_element_size(Qcur_full));
     cb(Qcur, "Qcur", il);
     cb(Qcur, "Qcur", il);
+    cb(gate, "gate", il);
+    
+    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
+    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+    cb(Qcur, "Qcur_reshaped", il);
+    
+    // Apply Q normalization only to the query part
+    Qcur = build_q3n_norm(Qcur, model.layers[il].attn_q_norm, il);
+    cb(Qcur, "Qcur_normed", il);
+    
+    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
 
 
     struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
     struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
     cb(Kcur, "Kcur", il);
     cb(Kcur, "Kcur", il);
@@ -123,14 +151,12 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
     struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);
     cb(Vcur, "Vcur", il);
 
 
-    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
 
     // Apply Q/K normalization
     // Apply Q/K normalization
-    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
-    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
-    cb(Kcur, "Qcur_normed", il);
+    Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il);
     cb(Kcur, "Kcur_normed", il);
     cb(Kcur, "Kcur_normed", il);
 
 
     // Apply RoPE
     // Apply RoPE
@@ -149,8 +175,8 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
         hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
     cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
     cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
 
 
-    // Apply gating
-    cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
+    // Apply gating directly using the original gate tensor
+    cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
     cb(cur, "attn_gated", il);
     cb(cur, "attn_gated", il);
 
 
     cur = build_lora_mm(model.layers[il].wo, cur);
     cur = build_lora_mm(model.layers[il].wo, cur);
@@ -598,7 +624,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     return cur;
     return cur;
 }
 }
 
 
-ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
+ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual) {
+
     // Check if this is an MoE layer
     // Check if this is an MoE layer
     if (model.layers[il].ffn_gate_inp != nullptr) {
     if (model.layers[il].ffn_gate_inp != nullptr) {
         // MoE branch
         // MoE branch
@@ -608,13 +635,33 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
                           n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
                           n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
         cb(moe_out, "ffn_moe_out", il);
         cb(moe_out, "ffn_moe_out", il);
 
 
-        // Add shared experts if present
+        // Add shared experts if present - following Qwen3Next reference implementation
         if (model.layers[il].ffn_up_shexp != nullptr) {
         if (model.layers[il].ffn_up_shexp != nullptr) {
             ggml_tensor * ffn_shexp =
             ggml_tensor * ffn_shexp =
                 build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
                 build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
                           model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
                           model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
             cb(ffn_shexp, "ffn_shexp", il);
             cb(ffn_shexp, "ffn_shexp", il);
 
 
+            // Apply shared expert gating as in the reference implementation
+            // The shared expert has its own gate that is sigmoided
+            // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
+            ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+            cb(shared_gate, "shared_expert_gate", il);
+            
+            // Apply sigmoid to the gate
+            shared_gate = ggml_sigmoid(ctx0, shared_gate);
+            cb(shared_gate, "shared_expert_gate_sigmoid", il);
+            
+            // The gate needs to be broadcast to match the dimensions of ffn_shexp
+            // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
+            // We need to repeat the gate along the feature dimension
+            shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
+            cb(shared_gate, "shared_expert_gate_broadcast", il);
+            
+            // Apply the gate to the shared expert output
+            ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+            cb(ffn_shexp, "ffn_shexp_gated", il);
+
             cur = ggml_add(ctx0, moe_out, ffn_shexp);
             cur = ggml_add(ctx0, moe_out, ffn_shexp);
             cb(cur, "ffn_out", il);
             cb(cur, "ffn_out", il);
         } else {
         } else {
@@ -626,9 +673,14 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
                         model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
                         model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
         cb(cur, "ffn_out", il);
     }
     }
-    // Residual connection
-    cur = ggml_add(ctx0, cur, cur);  // This should be the residual from before FFN
-    cb(cur, "ffn_residual", il);
+    // Residual connection (only if requested)
+    if (do_residual) {
+        cur = ggml_add(ctx0, cur, cur);
+        cb(cur, "ffn_residual", il);
+    }
+
+    cur = build_cvec(cur, il);
+    cb(cur, "l_out", il);
 
 
     return cur;
     return cur;
 };
 };

+ 1 - 1
src/models/llm_build_qwen3next.h

@@ -36,7 +36,7 @@ private:
                                                     const llama_ubatch & ubatch,
                                                     const llama_ubatch & ubatch,
                                                     int                  il);
                                                     int                  il);
 
 
-    ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il);
+    ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual = true);
 
 
     ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias);
     ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias);