|
|
@@ -90,6 +90,8 @@ const char * llm_type_name(llm_type type) {
|
|
|
case LLM_TYPE_57B_A14B: return "57B.A14B";
|
|
|
case LLM_TYPE_27B: return "27B";
|
|
|
case LLM_TYPE_290B: return "290B";
|
|
|
+ case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
|
|
|
+ case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
|
|
|
default: return "?B";
|
|
|
}
|
|
|
}
|
|
|
@@ -550,6 +552,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
|
}
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_LLAMA4:
|
|
|
+ {
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
|
|
+ ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
|
|
|
+ hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
|
|
|
+ hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
|
|
|
+ hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
|
|
|
+
|
|
|
+ switch (hparams.n_expert) {
|
|
|
+ case 16: type = LLM_TYPE_17B_16E; break;
|
|
|
+ case 128: type = LLM_TYPE_17B_128E; break;
|
|
|
+ default: type = LLM_TYPE_UNKNOWN;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (type == LLM_TYPE_17B_128E) {
|
|
|
+ hparams.use_kq_norm = false;
|
|
|
+ }
|
|
|
+ } break;
|
|
|
case LLM_ARCH_DECI:
|
|
|
{
|
|
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
@@ -1690,6 +1711,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
|
}
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_LLAMA4:
|
|
|
+ {
|
|
|
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
+
|
|
|
+ // output
|
|
|
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
|
|
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
|
|
+
|
|
|
+ // if output is NULL, init from the input tok embed
|
|
|
+ if (output == NULL) {
|
|
|
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
|
|
+ }
|
|
|
+
|
|
|
+ GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
|
|
|
+ for (int i = 0; i < n_layer; ++i) {
|
|
|
+ bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
|
|
|
+
|
|
|
+ auto & layer = layers[i];
|
|
|
+
|
|
|
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
|
|
+
|
|
|
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
|
|
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
|
|
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
|
|
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
|
|
+
|
|
|
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
|
|
+
|
|
|
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
|
|
+
|
|
|
+ if (is_moe_layer) {
|
|
|
+ int n_ff_exp = hparams.n_ff_exp;
|
|
|
+
|
|
|
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
|
|
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
|
|
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
|
|
|
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
|
|
+
|
|
|
+ // Shared expert
|
|
|
+ const int64_t n_ff_shexp = n_ff_exp;
|
|
|
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
|
|
|
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
|
|
|
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
|
|
|
+ } else {
|
|
|
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
|
|
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
|
|
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } break;
|
|
|
case LLM_ARCH_DECI:
|
|
|
{
|
|
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
@@ -4203,12 +4274,22 @@ struct llm_build_llama : public llm_graph_context {
|
|
|
// inp_pos - contains the positions
|
|
|
ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
|
|
+ // temperature tuning
|
|
|
+ ggml_tensor * inp_attn_scale = nullptr;
|
|
|
+ if (arch == LLM_ARCH_LLAMA4) {
|
|
|
+ inp_attn_scale = build_inp_attn_scale();
|
|
|
+ }
|
|
|
+
|
|
|
auto * inp_attn = build_attn_inp_kv_unified();
|
|
|
|
|
|
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
ggml_tensor * inpSA = inpL;
|
|
|
|
|
|
+ bool use_rope = arch == LLM_ARCH_LLAMA4
|
|
|
+ ? (il + 1) % hparams.n_no_rope_layer_step != 0
|
|
|
+ : true;
|
|
|
+
|
|
|
// norm
|
|
|
cur = build_norm(inpL,
|
|
|
model.layers[il].attn_norm, NULL,
|
|
|
@@ -4246,25 +4327,38 @@ struct llm_build_llama : public llm_graph_context {
|
|
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
|
|
|
- Qcur = ggml_rope_ext(
|
|
|
- ctx0, Qcur, inp_pos, rope_factors,
|
|
|
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
|
- ext_factor, attn_factor, beta_fast, beta_slow
|
|
|
- );
|
|
|
+ if (use_rope) {
|
|
|
+ Qcur = ggml_rope_ext(
|
|
|
+ ctx0, Qcur, inp_pos, rope_factors,
|
|
|
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow
|
|
|
+ );
|
|
|
|
|
|
- Kcur = ggml_rope_ext(
|
|
|
- ctx0, Kcur, inp_pos, rope_factors,
|
|
|
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
|
- ext_factor, attn_factor, beta_fast, beta_slow
|
|
|
- );
|
|
|
+ Kcur = ggml_rope_ext(
|
|
|
+ ctx0, Kcur, inp_pos, rope_factors,
|
|
|
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow
|
|
|
+ );
|
|
|
+ } else if (inp_attn_scale) {
|
|
|
+ Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
|
|
|
+ }
|
|
|
|
|
|
cb(Qcur, "Qcur", il);
|
|
|
cb(Kcur, "Kcur", il);
|
|
|
cb(Vcur, "Vcur", il);
|
|
|
|
|
|
+ if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
|
|
|
+ // Llama4TextL2Norm
|
|
|
+ Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
|
|
|
+ Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
|
|
|
+ cb(Qcur, "Qcur_normed", il);
|
|
|
+ cb(Kcur, "Kcur_normed", il);
|
|
|
+ }
|
|
|
+
|
|
|
cur = build_attn(inp_attn, gf,
|
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
|
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
|
|
|
+ cb(cur, "attn_out", il);
|
|
|
}
|
|
|
|
|
|
if (il == n_layer - 1) {
|
|
|
@@ -4282,7 +4376,7 @@ struct llm_build_llama : public llm_graph_context {
|
|
|
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
|
|
- // feed-forward network
|
|
|
+ // feed-forward network (non-MoE)
|
|
|
if (model.layers[il].ffn_gate_inp == nullptr) {
|
|
|
|
|
|
cur = build_norm(ffn_inp,
|
|
|
@@ -4297,6 +4391,38 @@ struct llm_build_llama : public llm_graph_context {
|
|
|
NULL,
|
|
|
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
cb(cur, "ffn_out", il);
|
|
|
+
|
|
|
+ } else if (arch == LLM_ARCH_LLAMA4) {
|
|
|
+ // llama4 MoE
|
|
|
+ ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
|
|
|
+ model.layers[il].ffn_norm, NULL,
|
|
|
+ LLM_NORM_RMS, il);
|
|
|
+ cb(cur, "ffn_norm", il);
|
|
|
+
|
|
|
+ ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed,
|
|
|
+ model.layers[il].ffn_gate_inp,
|
|
|
+ model.layers[il].ffn_up_exps,
|
|
|
+ model.layers[il].ffn_gate_exps,
|
|
|
+ model.layers[il].ffn_down_exps,
|
|
|
+ nullptr,
|
|
|
+ n_expert, n_expert_used,
|
|
|
+ LLM_FFN_SILU, false,
|
|
|
+ false, 0.0,
|
|
|
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
|
|
|
+ il);
|
|
|
+
|
|
|
+ // Shared experts
|
|
|
+ ggml_tensor * shexp_out = build_ffn(ffn_inp_normed,
|
|
|
+ model.layers[il].ffn_up_shexp, NULL, NULL,
|
|
|
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
|
|
|
+ model.layers[il].ffn_down_shexp, NULL, NULL,
|
|
|
+ NULL,
|
|
|
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
|
|
|
+ cb(shexp_out, "ffn_moe_shexp", il);
|
|
|
+
|
|
|
+ cur = ggml_add(ctx0, moe_out, shexp_out);
|
|
|
+ cb(cur, "ffn_moe_out_merged", il);
|
|
|
+
|
|
|
} else {
|
|
|
// MoE branch
|
|
|
cur = build_norm(ffn_inp,
|
|
|
@@ -12091,6 +12217,7 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
|
|
|
|
switch (arch) {
|
|
|
case LLM_ARCH_LLAMA:
|
|
|
+ case LLM_ARCH_LLAMA4:
|
|
|
case LLM_ARCH_MINICPM:
|
|
|
case LLM_ARCH_GRANITE:
|
|
|
case LLM_ARCH_GRANITE_MOE:
|
|
|
@@ -12440,6 +12567,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|
|
|
|
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
|
|
case LLM_ARCH_LLAMA:
|
|
|
+ case LLM_ARCH_LLAMA4:
|
|
|
case LLM_ARCH_DECI:
|
|
|
case LLM_ARCH_BAICHUAN:
|
|
|
case LLM_ARCH_STARCODER:
|