Просмотр исходного кода

Proper order of attention operations

Piotr Wilkin 3 месяцев назад
Родитель
Сommit
df0b5bcf30
1 измененных файлов с 12 добавлено и 10 удалено
  1. 12 10
      src/models/llm_build_qwen3next.cpp

+ 12 - 10
src/models/llm_build_qwen3next.cpp

@@ -54,6 +54,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     llm_graph_context_mamba(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    //GGML_ASSERT(n_embd_head == hparams.n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -142,7 +143,8 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
                                                                    const llama_model &       model,
                                                                    const int64_t             n_embd_head,
                                                                    const int                 il) {
-    // compute Q and K and RoPE them
+    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
+
     // Qwen3Next uses a single Q projection that outputs query + gate
     struct ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur_full, "Qcur_full", il);
@@ -159,13 +161,9 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
     Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
     cb(Qcur, "Qcur_reshaped", il);
     
-    // Apply Q normalization only to the query part
+    // Apply Q normalization
     Qcur = build_q3n_norm(Qcur, model.layers[il].attn_q_norm, il);
     cb(Qcur, "Qcur_normed", il);
-    
-    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
-    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
-    cb(gate, "gate_reshaped", il);
 
     struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
     cb(Kcur, "Kcur", il);
@@ -173,14 +171,18 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
     struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);
 
+    // Apply K normalization
+    Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il);
+    cb(Kcur, "Kcur_normed", il);
+    
+    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
     Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-    // Apply Q/K normalization
-    Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il);
-    cb(Kcur, "Kcur_normed", il);
-
     // Apply RoPE
     Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
                          attn_factor, beta_fast, beta_slow);