|
@@ -57,20 +57,29 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
|
|
|
// Full attention layer
|
|
// Full attention layer
|
|
|
cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
|
|
cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
|
|
|
}
|
|
}
|
|
|
- // Post-attention norm
|
|
|
|
|
- cur = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
|
|
|
|
|
- cb(cur, "attn_post_norm", il);
|
|
|
|
|
|
|
|
|
|
if (il == n_layer - 1 && inp_out_ids) {
|
|
if (il == n_layer - 1 && inp_out_ids) {
|
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
// Residual connection
|
|
// Residual connection
|
|
|
cur = ggml_add(ctx0, cur, inpSA);
|
|
cur = ggml_add(ctx0, cur, inpSA);
|
|
|
cb(cur, "attn_residual", il);
|
|
cb(cur, "attn_residual", il);
|
|
|
|
|
|
|
|
- // FFN layer (MoE or dense)
|
|
|
|
|
- cur = build_layer_ffn(cur, model, il);
|
|
|
|
|
|
|
+ // Save the tensor before post-attention norm for residual connection
|
|
|
|
|
+ ggml_tensor * ffn_residual = cur;
|
|
|
|
|
+
|
|
|
|
|
+ // Post-attention norm
|
|
|
|
|
+ ggml_tensor * attn_post_norm = build_q3n_norm(cur, model.layers[il].attn_post_norm, il);
|
|
|
|
|
+ cb(attn_post_norm, "attn_post_norm", il);
|
|
|
|
|
+
|
|
|
|
|
+ // FFN layer (MoE or dense) - without residual connection
|
|
|
|
|
+ cur = build_layer_ffn(attn_post_norm, model, il, false);
|
|
|
|
|
+ cb(cur, "ffn_out", il);
|
|
|
|
|
+
|
|
|
|
|
+ // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
|
|
|
|
|
+ cur = ggml_add(ctx0, cur, ffn_residual);
|
|
|
cb(cur, "post_moe", il);
|
|
cb(cur, "post_moe", il);
|
|
|
|
|
|
|
|
// Input for next layer
|
|
// Input for next layer
|
|
@@ -111,11 +120,30 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
|
|
|
const llama_model & model,
|
|
const llama_model & model,
|
|
|
const int64_t n_embd_head,
|
|
const int64_t n_embd_head,
|
|
|
const int il) {
|
|
const int il) {
|
|
|
- ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
|
|
|
|
|
-
|
|
|
|
|
// compute Q and K and RoPE them
|
|
// compute Q and K and RoPE them
|
|
|
- struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
|
|
|
|
|
+ // Qwen3Next uses a single Q projection that outputs query + gate
|
|
|
|
|
+ struct ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
|
|
|
|
|
+ cb(Qcur_full, "Qcur_full", il);
|
|
|
|
|
+ Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
|
|
|
|
|
+ // Split Q projection into query and gate
|
|
|
|
|
+ // The split should be along dimension 0 (the feature dimension)
|
|
|
|
|
+ struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
|
|
|
|
|
+ struct ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3],
|
|
|
|
|
+ n_embd_head * ggml_element_size(Qcur_full));
|
|
|
cb(Qcur, "Qcur", il);
|
|
cb(Qcur, "Qcur", il);
|
|
|
|
|
+ cb(gate, "gate", il);
|
|
|
|
|
+
|
|
|
|
|
+ // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
|
|
|
|
|
+ Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
|
+ cb(Qcur, "Qcur_reshaped", il);
|
|
|
|
|
+
|
|
|
|
|
+ // Apply Q normalization only to the query part
|
|
|
|
|
+ Qcur = build_q3n_norm(Qcur, model.layers[il].attn_q_norm, il);
|
|
|
|
|
+ cb(Qcur, "Qcur_normed", il);
|
|
|
|
|
+
|
|
|
|
|
+ // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
|
|
|
|
|
+ gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
|
|
|
|
|
+ cb(gate, "gate_reshaped", il);
|
|
|
|
|
|
|
|
struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
|
struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
|
|
cb(Kcur, "Kcur", il);
|
|
cb(Kcur, "Kcur", il);
|
|
@@ -123,14 +151,12 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
|
|
|
struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
|
struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
|
|
cb(Vcur, "Vcur", il);
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
|
|
|
|
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
|
|
|
|
+ Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
|
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
|
|
|
|
// Apply Q/K normalization
|
|
// Apply Q/K normalization
|
|
|
- Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
|
|
|
|
- Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
|
|
|
|
- cb(Kcur, "Qcur_normed", il);
|
|
|
|
|
|
|
+ Kcur = build_q3n_norm(Kcur, model.layers[il].attn_k_norm, il);
|
|
|
cb(Kcur, "Kcur_normed", il);
|
|
cb(Kcur, "Kcur_normed", il);
|
|
|
|
|
|
|
|
// Apply RoPE
|
|
// Apply RoPE
|
|
@@ -149,8 +175,8 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
|
|
|
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
|
cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
|
cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
|
|
|
|
|
|
|
- // Apply gating
|
|
|
|
|
- cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
|
|
|
|
|
|
|
+ // Apply gating directly using the original gate tensor
|
|
|
|
|
+ cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
|
|
|
cb(cur, "attn_gated", il);
|
|
cb(cur, "attn_gated", il);
|
|
|
|
|
|
|
|
cur = build_lora_mm(model.layers[il].wo, cur);
|
|
cur = build_lora_mm(model.layers[il].wo, cur);
|
|
@@ -598,7 +624,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
|
|
|
return cur;
|
|
return cur;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
|
|
|
|
|
|
|
+ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual) {
|
|
|
|
|
+
|
|
|
// Check if this is an MoE layer
|
|
// Check if this is an MoE layer
|
|
|
if (model.layers[il].ffn_gate_inp != nullptr) {
|
|
if (model.layers[il].ffn_gate_inp != nullptr) {
|
|
|
// MoE branch
|
|
// MoE branch
|
|
@@ -608,13 +635,33 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
|
|
|
n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
|
|
n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
|
|
|
cb(moe_out, "ffn_moe_out", il);
|
|
cb(moe_out, "ffn_moe_out", il);
|
|
|
|
|
|
|
|
- // Add shared experts if present
|
|
|
|
|
|
|
+ // Add shared experts if present - following Qwen3Next reference implementation
|
|
|
if (model.layers[il].ffn_up_shexp != nullptr) {
|
|
if (model.layers[il].ffn_up_shexp != nullptr) {
|
|
|
ggml_tensor * ffn_shexp =
|
|
ggml_tensor * ffn_shexp =
|
|
|
build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
|
model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
cb(ffn_shexp, "ffn_shexp", il);
|
|
cb(ffn_shexp, "ffn_shexp", il);
|
|
|
|
|
|
|
|
|
|
+ // Apply shared expert gating as in the reference implementation
|
|
|
|
|
+ // The shared expert has its own gate that is sigmoided
|
|
|
|
|
+ // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
|
|
|
|
|
+ ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
|
|
|
|
|
+ cb(shared_gate, "shared_expert_gate", il);
|
|
|
|
|
+
|
|
|
|
|
+ // Apply sigmoid to the gate
|
|
|
|
|
+ shared_gate = ggml_sigmoid(ctx0, shared_gate);
|
|
|
|
|
+ cb(shared_gate, "shared_expert_gate_sigmoid", il);
|
|
|
|
|
+
|
|
|
|
|
+ // The gate needs to be broadcast to match the dimensions of ffn_shexp
|
|
|
|
|
+ // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
|
|
|
|
|
+ // We need to repeat the gate along the feature dimension
|
|
|
|
|
+ shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
|
|
|
|
|
+ cb(shared_gate, "shared_expert_gate_broadcast", il);
|
|
|
|
|
+
|
|
|
|
|
+ // Apply the gate to the shared expert output
|
|
|
|
|
+ ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
|
|
|
|
|
+ cb(ffn_shexp, "ffn_shexp_gated", il);
|
|
|
|
|
+
|
|
|
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
|
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
|
|
cb(cur, "ffn_out", il);
|
|
cb(cur, "ffn_out", il);
|
|
|
} else {
|
|
} else {
|
|
@@ -626,9 +673,14 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
|
|
|
model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
cb(cur, "ffn_out", il);
|
|
cb(cur, "ffn_out", il);
|
|
|
}
|
|
}
|
|
|
- // Residual connection
|
|
|
|
|
- cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
|
|
|
|
|
- cb(cur, "ffn_residual", il);
|
|
|
|
|
|
|
+ // Residual connection (only if requested)
|
|
|
|
|
+ if (do_residual) {
|
|
|
|
|
+ cur = ggml_add(ctx0, cur, cur);
|
|
|
|
|
+ cb(cur, "ffn_residual", il);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ cur = build_cvec(cur, il);
|
|
|
|
|
+ cb(cur, "l_out", il);
|
|
|
|
|
|
|
|
return cur;
|
|
return cur;
|
|
|
};
|
|
};
|