@@ -12,7 +12,6 @@
#include "llama-memory-recurrent.h"

#include "ggml-cpp.h"
-#include "ggml-delta.h"

#include <algorithm>
#include <cassert>
@@ -18970,9 +18969,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
        struct ggml_tensor * inpSA = inpL;

        // Pre-norm for attention/linear attention
-        cur = build_norm(inpL,
-                model.layers[il].attn_norm, NULL,
-                LLM_NORM_RMS, il);
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // Determine layer type and build appropriate attention mechanism
@@ -18981,19 +18978,15 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
            cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
        } else {
            // Full attention layer
-            cur = build_qwen3next_attention_layer(
-                cur, inp_pos, inp->get_attn(), model,
-                n_embd_head, il);
+            cur = build_qwen3next_attention_layer(cur, inp_pos, inp->get_attn(), model, n_embd_head, il);
        }

        // Post-attention norm
-        cur = build_norm(cur,
-                model.layers[il].attn_post_norm, NULL,
-                LLM_NORM_RMS, il);
+        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        if (il == n_layer - 1 && inp_out_ids) {
-            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

@@ -19011,9 +19004,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
        cur = inpL;

        // Final norm
-        cur = build_norm(cur,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, -1);
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

        cb(cur, "result_norm", -1);
        res->t_embd = cur;
@@ -19028,15 +19019,148 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
        ggml_build_forward_expand(gf, cur);
    }

-private:
-    ggml_tensor * build_qwen3next_attention_layer(
-        ggml_tensor * cur,
-        ggml_tensor * inp_pos,
-        llm_graph_input_attn_kv * inp_attn,
-        const llama_model & model,
-        const int64_t n_embd_head,
-        const int il) {
+  private:
+    // ggml_delta_net
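+    // Wrapper around ggml_delta_net_op below: checks operand shapes, repeats the K/Q heads
+    // to match the V head count when the two differ, and permutes every operand into the
+    // (dim, head, seq, token) layout the op expects.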
+    struct ggml_tensor * ggml_delta_net(struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * q, struct ggml_tensor * g,
+                                        struct ggml_tensor * beta, struct ggml_tensor * state, bool use_qk_l2norm, float scale, int il) {
+        GGML_ASSERT(ggml_is_contiguous(k));
+        GGML_ASSERT(ggml_is_contiguous(v));
+        GGML_ASSERT(ggml_is_contiguous(q));
+        GGML_ASSERT(ggml_is_contiguous(g));
+        GGML_ASSERT(ggml_is_contiguous(beta));
+        GGML_ASSERT(ggml_is_contiguous(state));
+
+        const int64_t S_k = k->ne[0];
+        const int64_t H_k = k->ne[1];
+        const int64_t n_tokens = k->ne[2];
+        const int64_t n_seqs = k->ne[3];
+
+        const int64_t S_v = v->ne[0];
+        const int64_t H_v = v->ne[1];
+
+        GGML_ASSERT(v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[2] == n_tokens);
+        GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && beta->ne[3] == n_seqs);
+        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seqs &&
+                    state->ne[3] == n_tokens);
+
+        GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+        GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+
+        GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+
+        // Beta sigmoid
+        struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx0, beta);
+        cb(beta_sigmoid, "beta_sigmoid", il);
+
+        // Gate calculations are done elsewhere in llama-model.cpp
+
+        struct ggml_tensor * q_broadcast = q;
+        struct ggml_tensor * k_broadcast = k;
+
+        // if the key/query head count differs from the value head count, repeat K and Q to force matching shapes
+        if (H_k != H_v) {
+            GGML_ASSERT(H_v % H_k == 0);
+            int64_t repeat_factor = H_v / H_k;
+
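+            // repeat along the token axis, then reshape so the extra copies are folded
+            // back onto the head axis, yielding H_v key/query heads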
+            q_broadcast = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
+            k_broadcast = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
+
+            q_broadcast = ggml_repeat_4d(ctx0, q_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
+            k_broadcast = ggml_repeat_4d(ctx0, k_broadcast, S_k, n_tokens * repeat_factor, H_k, n_seqs);
+
+            q_broadcast = ggml_reshape_4d(ctx0, q_broadcast, S_k, H_v, n_seqs, n_tokens);
+            k_broadcast = ggml_reshape_4d(ctx0, k_broadcast, S_k, H_v, n_seqs, n_tokens);
+        }
+
+        struct ggml_tensor * v_reshape = ggml_cont_4d(ctx0, v, S_v, H_v, n_seqs, n_tokens);
+        struct ggml_tensor * g_reshape = ggml_cont_4d(ctx0, g, S_v, H_v, n_seqs, n_tokens);
+        struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx0, beta_sigmoid, 1, H_v, n_seqs, n_tokens);
+        struct ggml_tensor * state_broadcast = ggml_cont(ctx0, state);
+
+        return ggml_delta_net_op(q_broadcast, k_broadcast, v_reshape, g_reshape, beta_broadcast, state_broadcast,
+                                 use_qk_l2norm, scale, il);
+    }
+
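+    // Gated delta-rule update expressed entirely with elementwise mul/add, row sums and
+    // permutes, so no dedicated ggml operator (formerly ggml-delta.h) is required. The
+    // result packs the attention output and the updated state, concatenated along dim 1.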
+    struct ggml_tensor * ggml_delta_net_op(struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * g,
+                                           struct ggml_tensor * beta, struct ggml_tensor * state, bool use_qk_l2norm, float scale, int il) {
+        GGML_ASSERT(ggml_is_contiguous(q));
+        GGML_ASSERT(ggml_is_contiguous(k));
+        GGML_ASSERT(ggml_is_contiguous(v));
+        GGML_ASSERT(ggml_is_contiguous(g));
+        GGML_ASSERT(ggml_is_contiguous(beta));
+        GGML_ASSERT(ggml_is_contiguous(state));
+
+        const int64_t S_k = q->ne[0];
+        const int64_t H_k = q->ne[1];
+        const int64_t n_seq = q->ne[2];
+        const int64_t n_tokens = q->ne[3];
+
+        const int64_t S_v = v->ne[0];
+        const int64_t H_v = v->ne[1];
+
+        GGML_ASSERT(H_k == H_v); // we broadcasted the tensors in the main function to guarantee this
+
+        GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_seq && k->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_seq && v->ne[3] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_seq && g->ne[3] == n_tokens);
+        GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == H_v && beta->ne[2] == n_seq && beta->ne[3] == n_tokens);
+        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == n_seq &&
+                    state->ne[3] == n_tokens);
+
+        struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, S_v, S_v * H_v, n_seq, n_tokens);
+
+        new_state = ggml_cpy(ctx0, state, new_state);
+        cb(new_state, "new_state", il);
+
+        if (use_qk_l2norm) {
+            q = ggml_l2_norm(ctx0, q, 1e-6f);
+            cb(q, "q_l2_norm", il);
+            k = ggml_l2_norm(ctx0, k, 1e-6f);
+            cb(k, "k_l2_norm", il);
+        }
+
+        q = ggml_scale(ctx0, q, scale);
+        cb(q, "q_scaled", il);
+
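+        // Gated delta rule, step by step: decay the state with g, read the decayed state
+        // out along k, take the beta-weighted delta between that readout and v, write the
+        // delta back along k, and finally read the updated state out with q.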
+        struct ggml_tensor * state_decay = ggml_mul(ctx0, state, g);
+        cb(state_decay, "state_decay", il);
+        struct ggml_tensor * kv_mem_presum = ggml_mul(ctx0, state_decay, k);
+
+        // ggml tensors are limited to 4 dims, so fold the seq and token dims together
+        // before the permute + row-wise sum
+        struct ggml_tensor * kv_mem_presum_squeeze =
+            ggml_reshape_4d(ctx0, kv_mem_presum, S_v, S_v, H_v, n_seq * n_tokens);
+
+        struct ggml_tensor * kv_mem = ggml_permute(
+            ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, kv_mem_presum_squeeze, 1, 2, 0, 3))), 2, 0, 1, 3);
+        cb(kv_mem, "kv_mem", il);
+        struct ggml_tensor * kv_mem_reshape = ggml_reshape_4d(ctx0, kv_mem, S_v, S_v, n_seq, n_tokens);
+        struct ggml_tensor * delta = ggml_mul(ctx0, ggml_sub(ctx0, kv_mem_reshape, v), beta);
+        cb(delta, "delta", il);
+        struct ggml_tensor * delta_kt = ggml_mul(ctx0, delta, k);
+        cb(delta_kt, "delta_kt", il);
+        struct ggml_tensor * state_plus_k_delta = ggml_add(ctx0, state_decay, delta_kt);
+        cb(state_plus_k_delta, "state_plus_k_delta", il);
+        struct ggml_tensor * state_q = ggml_mul(ctx0, state_plus_k_delta, q);
+        cb(state_q, "state_q", il);
+
+        // same fold-and-sum trick for the readout, then restore the 4d shape
+        state_q = ggml_reshape_4d(ctx0, state_q, S_v, S_v, H_v, n_seq * n_tokens);
+        struct ggml_tensor * output = ggml_permute(ctx0, ggml_sum_rows(ctx0, state_q), 2, 0, 1, 3);
+        output = ggml_reshape_4d(ctx0, output, S_v, H_v, n_seq, n_tokens);
+        cb(output, "delta_net_output", il);
+
+        struct ggml_tensor * result = ggml_concat(ctx0, output, state_plus_k_delta, 1);
+        cb(result, "delta_net_result", il);
+        return result;
+    }
+
+    ggml_tensor * build_qwen3next_attention_layer(ggml_tensor * cur,
+                                                  ggml_tensor * inp_pos,
+                                                  llm_graph_input_attn_kv * inp_attn,
+                                                  const llama_model & model,
+                                                  const int64_t n_embd_head,
+                                                  const int il) {
        ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);

        // compute Q and K and RoPE them
@@ -19060,30 +19184,26 @@ private:
        cb(Kcur, "Kcur_normed", il);

        // Apply RoPE
-        Qcur = ggml_rope_ext(
-            ctx0, Qcur, inp_pos, nullptr,
-            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-            ext_factor, attn_factor, beta_fast, beta_slow);
+        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                             ext_factor, attn_factor, beta_fast, beta_slow);

-        Kcur = ggml_rope_ext(
-            ctx0, Kcur, inp_pos, nullptr,
-            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-            ext_factor, attn_factor, beta_fast, beta_slow);
+        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                             ext_factor, attn_factor, beta_fast, beta_slow);

        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

        // Attention computation
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        cur = build_attn(inp_attn,
-                model.layers[il].wo, nullptr,
-                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
-
+        const float kq_scale =
+            hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale,
+                         il);
+
        // Apply gating
        cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
        cb(cur, "attn_gated", il);
-
+
        return cur;
    }

@@ -19252,16 +19372,18 @@ private:
        cb(conv_output, "conv_output_final", il);

        // Extract the convolved Q, K, V from conv_output
-        ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
-                head_k_dim, conv_output->nb[1], conv_output->nb[2], 0));
-        cb(q_conv, "q_conv", il);
-        ggml_tensor * k_conv =
+        ggml_tensor * q_conv =
            ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs, head_k_dim,
-                conv_output->nb[1], conv_output->nb[2], head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+                                         conv_output->nb[1], conv_output->nb[2], 0));
+        cb(q_conv, "q_conv", il);
+        ggml_tensor * k_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens,
+                                                            n_seqs, head_k_dim, conv_output->nb[1], conv_output->nb[2],
+                                                            head_k_dim * num_k_heads * ggml_element_size(conv_output)));
        cb(q_conv, "k_conv", il);
        ggml_tensor * v_conv =
            ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs, head_v_dim,
-                conv_output->nb[1], conv_output->nb[2], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+                         conv_output->nb[1], conv_output->nb[2],
+                         2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
        cb(q_conv, "v_conv", il);

        ggml_build_forward_expand(gf, ssm_states_all);
@@ -19274,28 +19396,21 @@ private:
        ggml_tensor * target_gate = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
        ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
        gate = ggml_repeat(ctx0, gate_broadcast, target_gate);
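+        // the gate is one scalar per head; expand it across head_dim so it lines up with
+        // the per-head state layout that ggml_delta_net expects for g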
+        cb(gate, "gate", il);

        // Call the new ggml_delta_net function with the corrected flow
-        ggml_tensor * output = ggml_delta_net(ctx0,
-            k_conv, // k tensor (already convolved)
-            v_conv, // v tensor (already convolved)
-            q_conv, // q tensor (already convolved)
-            gate, // g tensor
-            beta, // beta tensor
-            state_broadcast, // state tensor
-            true, // use_qk_l2norm
-            1.0f // scale
-        );
-        cb(output, "delta_net_output", il);
+        ggml_tensor * output = ggml_delta_net(k_conv, v_conv, q_conv, gate, beta, state_broadcast, true, 1.0f, il);
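+        // ggml_delta_net returns the per-token outputs and the updated recurrent state
+        // concatenated along dim 1; the views below pull the two halves apart again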

        // Extract the output part
        ggml_tensor * attn_out = ggml_view_4d(ctx0, output, head_dim, n_heads, n_tokens, n_seqs, output->nb[0],
                                              output->nb[1], output->nb[2], 0);
+        cb(attn_out, "attn_out", il);

        // Extract the new state
        ggml_tensor * new_state =
            ggml_view_4d(ctx0, output, head_dim, head_dim * n_heads, n_tokens, n_seqs, output->nb[0], output->nb[1],
                         output->nb[2], n_tokens * n_seqs * head_dim * n_heads * ggml_element_size(output));
+        cb(new_state, "new_state", il);

        // Only return the last recurrent state
        struct ggml_tensor * state_reshaped =
@@ -19303,6 +19418,7 @@ private:
        struct ggml_tensor * state_last = ggml_view_4d(
            ctx0, state_reshaped, head_dim, head_dim, n_heads, 1, state_reshaped->nb[1], state_reshaped->nb[2],
            state_reshaped->nb[3], head_dim * head_dim * n_heads * ((n_seqs * n_tokens) - 1));
+        cb(state_last, "new_state_last", il);

        // Update the recurrent states
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_last, ssm_states_all));
@@ -19318,16 +19434,20 @@ private:
        // Apply gated normalization: self.norm(core_attn_out, z)
        // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
        ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        cb(attn_out_norm, "attn_out_norm", il);

        // Apply silu gate: attn_out_norm * silu(z_2d)
        ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
+        cb(z_silu, "z_silu", il);
        ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
+        cb(gated_output, "gated_output", il);

        // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
        ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);

        // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
        ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
+        cb(final_output, "final_output", il);

        // Output projection
        cur = build_lora_mm(model.layers[il].ssm_out, final_output);
@@ -19343,27 +19463,17 @@ private:
        // Check if this is an MoE layer
        if (model.layers[il].ffn_gate_inp != nullptr) {
            // MoE branch
-            ggml_tensor * moe_out = build_moe_ffn(cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                              model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert,
+                              n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
            cb(moe_out, "ffn_moe_out", il);

            // Add shared experts if present
            if (model.layers[il].ffn_up_shexp != nullptr) {
-                ggml_tensor * ffn_shexp = build_ffn(cur,
-                        model.layers[il].ffn_up_shexp, NULL, NULL,
-                        model.layers[il].ffn_gate_shexp, NULL, NULL,
-                        model.layers[il].ffn_down_shexp, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL,
+                              NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(ffn_shexp, "ffn_shexp", il);

                cur = ggml_add(ctx0, moe_out, ffn_shexp);
@@ -19373,17 +19483,13 @@ private:
|
|
|
}
|
|
|
} else {
|
|
|
// Dense FFN branch
|
|
|
- cur = build_ffn(cur,
|
|
|
- model.layers[il].ffn_up, NULL, NULL,
|
|
|
- model.layers[il].ffn_gate, NULL, NULL,
|
|
|
- model.layers[il].ffn_down, NULL, NULL,
|
|
|
- NULL,
|
|
|
- LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
+ cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
|
|
|
+ model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
cb(cur, "ffn_out", il);
|
|
|
}
|
|
|
|
|
|
// Residual connection
|
|
|
- cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
|
|
|
+ cur = ggml_add(ctx0, cur, cur); // This should be the residual from before FFN
|
|
|
cb(cur, "ffn_residual", il);
|
|
|
|
|
|
return cur;
|
|
|
@@ -19398,7 +19504,6 @@ private:
    }
};

-
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
    llama_memory_i * res;