|
@@ -105,7 +105,7 @@
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
#define LLAMA_MAX_NODES 8192
|
|
#define LLAMA_MAX_NODES 8192
|
|
|
-#define LLAMA_MAX_EXPERTS 8
|
|
|
|
|
|
|
+#define LLAMA_MAX_EXPERTS 16
|
|
|
|
|
|
|
|
|
|
|
|
|
//
|
|
//
|
|
@@ -220,6 +220,7 @@ enum llm_arch {
|
|
|
LLM_ARCH_MAMBA,
|
|
LLM_ARCH_MAMBA,
|
|
|
LLM_ARCH_XVERSE,
|
|
LLM_ARCH_XVERSE,
|
|
|
LLM_ARCH_COMMAND_R,
|
|
LLM_ARCH_COMMAND_R,
|
|
|
|
|
+ LLM_ARCH_DBRX,
|
|
|
LLM_ARCH_UNKNOWN,
|
|
LLM_ARCH_UNKNOWN,
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
|
{ LLM_ARCH_MAMBA, "mamba" },
|
|
{ LLM_ARCH_MAMBA, "mamba" },
|
|
|
{ LLM_ARCH_XVERSE, "xverse" },
|
|
{ LLM_ARCH_XVERSE, "xverse" },
|
|
|
{ LLM_ARCH_COMMAND_R, "command-r" },
|
|
{ LLM_ARCH_COMMAND_R, "command-r" },
|
|
|
|
|
+ { LLM_ARCH_DBRX, "dbrx" },
|
|
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -934,6 +936,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|
|
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
|
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
|
|
},
|
|
},
|
|
|
},
|
|
},
|
|
|
|
|
+ {
|
|
|
|
|
+ LLM_ARCH_DBRX,
|
|
|
|
|
+ {
|
|
|
|
|
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
|
|
|
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
|
|
|
|
+ { LLM_TENSOR_OUTPUT, "output" },
|
|
|
|
|
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
|
|
|
|
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
|
|
|
|
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
|
|
|
|
+ { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
|
|
|
|
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
|
|
|
|
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
|
|
|
|
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
|
|
|
|
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
{
|
|
{
|
|
|
LLM_ARCH_UNKNOWN,
|
|
LLM_ARCH_UNKNOWN,
|
|
|
{
|
|
{
|
|
@@ -1707,6 +1725,7 @@ enum e_model {
|
|
|
MODEL_XL,
|
|
MODEL_XL,
|
|
|
MODEL_8x7B,
|
|
MODEL_8x7B,
|
|
|
MODEL_8x22B,
|
|
MODEL_8x22B,
|
|
|
|
|
+ MODEL_16x12B,
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
static const size_t kiB = 1024;
|
|
static const size_t kiB = 1024;
|
|
@@ -3562,6 +3581,7 @@ static const char * llama_model_type_name(e_model type) {
|
|
|
case MODEL_XL: return "1.5B";
|
|
case MODEL_XL: return "1.5B";
|
|
|
case MODEL_8x7B: return "8x7B";
|
|
case MODEL_8x7B: return "8x7B";
|
|
|
case MODEL_8x22B: return "8x22B";
|
|
case MODEL_8x22B: return "8x22B";
|
|
|
|
|
+ case MODEL_16x12B: return "16x12B";
|
|
|
default: return "?B";
|
|
default: return "?B";
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -3983,6 +4003,16 @@ static void llm_load_hparams(
|
|
|
default: model.type = e_model::MODEL_UNKNOWN;
|
|
default: model.type = e_model::MODEL_UNKNOWN;
|
|
|
}
|
|
}
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case LLM_ARCH_DBRX:
|
|
|
|
|
+ {
|
|
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
|
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
|
|
|
|
|
+
|
|
|
|
|
+ switch (hparams.n_layer) {
|
|
|
|
|
+ case 40: model.type = e_model::MODEL_16x12B; break;
|
|
|
|
|
+ default: model.type = e_model::MODEL_UNKNOWN;
|
|
|
|
|
+ }
|
|
|
|
|
+ } break;
|
|
|
default: (void)0;
|
|
default: (void)0;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -4671,6 +4701,39 @@ static bool llm_load_tensors(
|
|
|
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
|
|
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
|
|
|
}
|
|
}
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case LLM_ARCH_DBRX:
|
|
|
|
|
+ {
|
|
|
|
|
+ if (n_expert == 0) {
|
|
|
|
|
+ throw std::runtime_error("DBRX model cannot have zero experts");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
|
|
|
+
|
|
|
|
|
+ // output
|
|
|
|
|
+ {
|
|
|
|
|
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
|
|
|
|
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for (int i = 0; i < n_layer; ++i) {
|
|
|
|
|
+ ggml_context * ctx_layer = ctx_for_layer(i);
|
|
|
|
|
+ ggml_context * ctx_split = ctx_for_layer_split(i);
|
|
|
|
|
+
|
|
|
|
|
+ auto & layer = model.layers[i];
|
|
|
|
|
+
|
|
|
|
|
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
|
|
|
|
+
|
|
|
|
|
+ layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
|
|
|
|
|
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
|
|
|
|
+
|
|
|
|
|
+ layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
|
|
|
|
|
+
|
|
|
|
|
+ layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
|
|
|
|
+ layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert});
|
|
|
|
|
+ layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
|
|
|
|
|
+ layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
|
|
|
|
|
+ }
|
|
|
|
|
+ } break;
|
|
|
case LLM_ARCH_BAICHUAN:
|
|
case LLM_ARCH_BAICHUAN:
|
|
|
{
|
|
{
|
|
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
@@ -6433,62 +6496,7 @@ struct llm_build_context {
|
|
|
LLM_NORM_RMS, cb, il);
|
|
LLM_NORM_RMS, cb, il);
|
|
|
cb(cur, "ffn_norm", il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
|
|
|
|
- ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
|
|
|
|
- cb(logits, "ffn_moe_logits", il);
|
|
|
|
|
-
|
|
|
|
|
- ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
|
|
|
|
- cb(probs, "ffn_moe_probs", il);
|
|
|
|
|
-
|
|
|
|
|
- // select experts
|
|
|
|
|
- ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
- cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
|
|
|
|
-
|
|
|
|
|
- ggml_tensor * weights = ggml_get_rows(ctx0,
|
|
|
|
|
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
|
|
|
|
|
- cb(weights, "ffn_moe_weights", il);
|
|
|
|
|
-
|
|
|
|
|
- weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
-
|
|
|
|
|
- ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
|
|
|
|
- cb(weights_sum, "ffn_moe_weights_sum", il);
|
|
|
|
|
-
|
|
|
|
|
- weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
- cb(weights, "ffn_moe_weights_norm", il);
|
|
|
|
|
-
|
|
|
|
|
- // compute expert outputs
|
|
|
|
|
- ggml_tensor * moe_out = nullptr;
|
|
|
|
|
-
|
|
|
|
|
- for (int i = 0; i < n_expert_used; ++i) {
|
|
|
|
|
- ggml_tensor * cur_expert;
|
|
|
|
|
-
|
|
|
|
|
- ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
|
|
|
|
|
- cb(cur_up, "ffn_moe_up", il);
|
|
|
|
|
-
|
|
|
|
|
- ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
|
|
|
|
|
- cb(cur_gate, "ffn_moe_gate", il);
|
|
|
|
|
-
|
|
|
|
|
- cur_gate = ggml_silu(ctx0, cur_gate);
|
|
|
|
|
- cb(cur_gate, "ffn_moe_silu", il);
|
|
|
|
|
-
|
|
|
|
|
- cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
|
|
|
|
|
- cb(cur_expert, "ffn_moe_gate_par", il);
|
|
|
|
|
-
|
|
|
|
|
- cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
|
|
|
|
- cb(cur_expert, "ffn_moe_down", il);
|
|
|
|
|
-
|
|
|
|
|
- cur_expert = ggml_mul(ctx0, cur_expert,
|
|
|
|
|
- ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
|
|
|
|
- cb(cur_expert, "ffn_moe_weighted", il);
|
|
|
|
|
-
|
|
|
|
|
- if (i == 0) {
|
|
|
|
|
- moe_out = cur_expert;
|
|
|
|
|
- } else {
|
|
|
|
|
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
|
|
|
|
- cb(moe_out, "ffn_moe_out", il);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- cur = moe_out;
|
|
|
|
|
|
|
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
@@ -6520,6 +6528,78 @@ struct llm_build_context {
|
|
|
return gf;
|
|
return gf;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
|
|
|
|
|
+ ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) {
|
|
|
|
|
+ ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
|
|
|
|
+ cb(logits, "ffn_moe_logits", il);
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
|
|
|
|
+ cb(probs, "ffn_moe_probs", il);
|
|
|
|
|
+
|
|
|
|
|
+ // select experts
|
|
|
|
|
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
+ cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor * weights = ggml_get_rows(ctx0,
|
|
|
|
|
+ ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
|
|
|
|
|
+ cb(weights, "ffn_moe_weights", il);
|
|
|
|
|
+
|
|
|
|
|
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
|
|
|
|
+ cb(weights_sum, "ffn_moe_weights_sum", il);
|
|
|
|
|
+
|
|
|
|
|
+ weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
+ cb(weights, "ffn_moe_weights_norm", il);
|
|
|
|
|
+
|
|
|
|
|
+ // compute expert outputs
|
|
|
|
|
+ ggml_tensor * moe_out = nullptr;
|
|
|
|
|
+
|
|
|
|
|
+ for (int i = 0; i < n_expert_used; ++i) {
|
|
|
|
|
+ ggml_tensor * cur_expert;
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
|
|
|
|
|
+ cb(cur_up, "ffn_moe_up", il);
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
|
|
|
|
|
+ cb(gate, "ffn_moe_gate", il);
|
|
|
|
|
+
|
|
|
|
|
+ switch (type_op) {
|
|
|
|
|
+ case LLM_FFN_SILU:
|
|
|
|
|
+ {
|
|
|
|
|
+ gate = ggml_silu(ctx0, gate);
|
|
|
|
|
+ cb(gate, "ffn_moe_silu", il);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ case LLM_FFN_GELU:
|
|
|
|
|
+ {
|
|
|
|
|
+ gate = ggml_gelu(ctx0, gate);
|
|
|
|
|
+ cb(gate, "ffn_moe_gelu", il);
|
|
|
|
|
+ } break;
|
|
|
|
|
+ default:
|
|
|
|
|
+ GGML_ASSERT(false);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ cur_expert = ggml_mul(ctx0, cur_up, gate);
|
|
|
|
|
+ cb(cur_expert, "ffn_moe_gate_par", il);
|
|
|
|
|
+
|
|
|
|
|
+ cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
|
|
|
|
+ cb(cur_expert, "ffn_moe_down", il);
|
|
|
|
|
+
|
|
|
|
|
+ cur_expert = ggml_mul(ctx0, cur_expert,
|
|
|
|
|
+ ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
|
|
|
|
+ cb(cur_expert, "ffn_moe_weighted", il);
|
|
|
|
|
+
|
|
|
|
|
+ if (i == 0) {
|
|
|
|
|
+ moe_out = cur_expert;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
|
|
|
|
+ cb(moe_out, "ffn_moe_out", il);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return moe_out;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
struct ggml_cgraph * build_baichuan() {
|
|
struct ggml_cgraph * build_baichuan() {
|
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
|
|
|
|
|
|
@@ -6967,74 +7047,143 @@ struct llm_build_context {
|
|
|
LLM_NORM_RMS, cb, il);
|
|
LLM_NORM_RMS, cb, il);
|
|
|
cb(cur, "ffn_norm", il);
|
|
cb(cur, "ffn_norm", il);
|
|
|
|
|
|
|
|
- ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
|
|
|
|
- cb(logits, "ffn_moe_logits", il);
|
|
|
|
|
|
|
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il);
|
|
|
|
|
+
|
|
|
|
|
+ // Grok
|
|
|
|
|
+ // if layer_out_norm is present then apply it before adding the input
|
|
|
|
|
+ // Idea: maybe ffn_out_norm is a better name
|
|
|
|
|
+ if (model.layers[il].layer_out_norm) {
|
|
|
|
|
+ cur = llm_build_norm(ctx0, cur, hparams,
|
|
|
|
|
+ model.layers[il].layer_out_norm, NULL,
|
|
|
|
|
+ LLM_NORM_RMS, cb, il);
|
|
|
|
|
+ cb(cur, "layer_out_norm", il);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ cur = ggml_add(ctx0, cur, ffn_inp);
|
|
|
|
|
+ cb(cur, "ffn_out", il);
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
|
|
|
|
+ if (layer_dir != nullptr) {
|
|
|
|
|
+ cur = ggml_add(ctx0, cur, layer_dir);
|
|
|
|
|
+ }
|
|
|
|
|
+ cb(cur, "l_out", il);
|
|
|
|
|
+
|
|
|
|
|
+ // input for next layer
|
|
|
|
|
+ inpL = cur;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ cur = inpL;
|
|
|
|
|
|
|
|
- ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
|
|
|
|
- cb(probs, "ffn_moe_probs", il);
|
|
|
|
|
|
|
+ cur = llm_build_norm(ctx0, cur, hparams,
|
|
|
|
|
+ model.output_norm, NULL,
|
|
|
|
|
+ LLM_NORM_RMS, cb, -1);
|
|
|
|
|
+ cb(cur, "result_norm", -1);
|
|
|
|
|
|
|
|
- // select experts
|
|
|
|
|
- ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
- cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
|
|
|
|
|
|
+ // lm_head
|
|
|
|
|
+ cur = ggml_mul_mat(ctx0, model.output, cur);
|
|
|
|
|
|
|
|
- ggml_tensor * weights = ggml_get_rows(ctx0,
|
|
|
|
|
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
|
|
|
|
|
- cb(weights, "ffn_moe_weights", il);
|
|
|
|
|
|
|
+ // Grok
|
|
|
|
|
+ // multiply logits by output_multiplier_scale of 0.5773502691896257
|
|
|
|
|
|
|
|
- weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
|
|
+ cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
|
|
|
|
|
|
|
|
- ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
|
|
|
|
- cb(weights_sum, "ffn_moe_weights_sum", il);
|
|
|
|
|
|
|
+ cb(cur, "result_output", -1);
|
|
|
|
|
|
|
|
- weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
|
|
|
|
- cb(weights, "ffn_moe_weights_norm", il);
|
|
|
|
|
|
|
+ ggml_build_forward_expand(gf, cur);
|
|
|
|
|
|
|
|
- // compute expert outputs
|
|
|
|
|
- ggml_tensor * moe_out = nullptr;
|
|
|
|
|
|
|
+ return gf;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- for (int i = 0; i < n_expert_used; ++i) {
|
|
|
|
|
- ggml_tensor * cur_expert;
|
|
|
|
|
|
|
+ struct ggml_cgraph * build_dbrx() {
|
|
|
|
|
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
|
|
|
|
|
|
|
- ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
|
|
|
|
|
- cb(cur_up, "ffn_moe_up", il);
|
|
|
|
|
|
|
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
|
|
|
|
|
+ int32_t n_tokens = this->n_tokens;
|
|
|
|
|
|
|
|
- ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
|
|
|
|
|
- cb(cur_gate, "ffn_moe_gate", il);
|
|
|
|
|
|
|
+ const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
|
|
|
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
|
|
|
|
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
|
|
|
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
|
|
|
|
|
- //GeLU
|
|
|
|
|
- cur_gate = ggml_gelu(ctx0, cur_gate);
|
|
|
|
|
- cb(cur_gate, "ffn_moe_gelu", il);
|
|
|
|
|
|
|
+ struct ggml_tensor * cur;
|
|
|
|
|
+ struct ggml_tensor * inpL;
|
|
|
|
|
|
|
|
- cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
|
|
|
|
|
- cb(cur_expert, "ffn_moe_gate_par", il);
|
|
|
|
|
|
|
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
|
|
|
|
|
|
|
- cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
|
|
|
|
- cb(cur_expert, "ffn_moe_down", il);
|
|
|
|
|
|
|
+ // inp_pos - contains the positions
|
|
|
|
|
+ struct ggml_tensor * inp_pos = build_inp_pos();
|
|
|
|
|
|
|
|
- cur_expert = ggml_mul(ctx0, cur_expert,
|
|
|
|
|
- ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
|
|
|
|
- cb(cur_expert, "ffn_moe_weighted", il);
|
|
|
|
|
|
|
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
|
|
|
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
|
|
|
|
- if (i == 0) {
|
|
|
|
|
- moe_out = cur_expert;
|
|
|
|
|
- } else {
|
|
|
|
|
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
|
|
|
|
- cb(moe_out, "ffn_moe_out", il);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for (int il = 0; il < n_layer; ++il) {
|
|
|
|
|
+ struct ggml_tensor * inpSA = inpL;
|
|
|
|
|
|
|
|
- cur = moe_out;
|
|
|
|
|
|
|
+ // norm
|
|
|
|
|
+ cur = llm_build_norm(ctx0, inpL, hparams,
|
|
|
|
|
+ model.layers[il].attn_norm, NULL,
|
|
|
|
|
+ LLM_NORM, cb, il);
|
|
|
|
|
+ cb(cur, "attn_norm", il);
|
|
|
|
|
|
|
|
- // Grok
|
|
|
|
|
- // if layer_out_norm is present then apply it before adding the input
|
|
|
|
|
- // Idea: maybe ffn_out_norm is a better name
|
|
|
|
|
- if (model.layers[il].layer_out_norm) {
|
|
|
|
|
- cur = llm_build_norm(ctx0, cur, hparams,
|
|
|
|
|
- model.layers[il].layer_out_norm, NULL,
|
|
|
|
|
- LLM_NORM_RMS, cb, il);
|
|
|
|
|
- cb(cur, "layer_out_norm", il);
|
|
|
|
|
|
|
+ // self-attention
|
|
|
|
|
+ {
|
|
|
|
|
+ struct ggml_tensor * Qcur = nullptr;
|
|
|
|
|
+ struct ggml_tensor * Kcur = nullptr;
|
|
|
|
|
+ struct ggml_tensor * Vcur = nullptr;
|
|
|
|
|
+
|
|
|
|
|
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
|
|
|
|
+ cb(cur, "wqkv", il);
|
|
|
|
|
+
|
|
|
|
|
+ cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
|
|
|
|
|
+ cb(cur, "wqkv_clamped", il);
|
|
|
|
|
+
|
|
|
|
|
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
|
|
|
|
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
|
|
|
|
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
|
|
|
|
+
|
|
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
|
|
+ cb(Vcur, "Vcur", il);
|
|
|
|
|
+
|
|
|
|
|
+ Qcur = ggml_rope_custom(
|
|
|
|
|
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
|
|
|
|
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
|
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow
|
|
|
|
|
+ );
|
|
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
|
|
+
|
|
|
|
|
+ Kcur = ggml_rope_custom(
|
|
|
|
|
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
|
|
|
|
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
|
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow
|
|
|
|
|
+ );
|
|
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
|
|
+
|
|
|
|
|
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
|
|
|
|
+ model.layers[il].wo, NULL,
|
|
|
|
|
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (il == n_layer - 1) {
|
|
|
|
|
+ // skip computing output for unused tokens
|
|
|
|
|
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
|
|
|
+ n_tokens = n_outputs;
|
|
|
|
|
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
|
|
|
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
|
|
|
+ cb(ffn_inp, "ffn_inp", il);
|
|
|
|
|
+
|
|
|
|
|
+ // feed-forward network
|
|
|
|
|
+ // MoE branch
|
|
|
|
|
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
|
|
|
|
+ model.layers[il].attn_out_norm, NULL,
|
|
|
|
|
+ LLM_NORM, cb, il);
|
|
|
|
|
+ cb(cur, "attn_out_norm", il);
|
|
|
|
|
+
|
|
|
|
|
+ cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
|
|
|
|
|
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
|
cb(cur, "ffn_out", il);
|
|
cb(cur, "ffn_out", il);
|
|
@@ -7052,18 +7201,13 @@ struct llm_build_context {
|
|
|
cur = inpL;
|
|
cur = inpL;
|
|
|
|
|
|
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
cur = llm_build_norm(ctx0, cur, hparams,
|
|
|
- model.output_norm, NULL,
|
|
|
|
|
- LLM_NORM_RMS, cb, -1);
|
|
|
|
|
|
|
+ model.output_norm, NULL,
|
|
|
|
|
+ LLM_NORM, cb, -1);
|
|
|
cb(cur, "result_norm", -1);
|
|
cb(cur, "result_norm", -1);
|
|
|
|
|
|
|
|
// lm_head
|
|
// lm_head
|
|
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
|
|
|
|
|
|
|
- // Grok
|
|
|
|
|
- // multiply logits by output_multiplier_scale of 0.5773502691896257
|
|
|
|
|
-
|
|
|
|
|
- cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
|
|
|
|
|
-
|
|
|
|
|
cb(cur, "result_output", -1);
|
|
cb(cur, "result_output", -1);
|
|
|
|
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
ggml_build_forward_expand(gf, cur);
|
|
@@ -9785,6 +9929,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|
|
{
|
|
{
|
|
|
result = llm.build_command_r();
|
|
result = llm.build_command_r();
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case LLM_ARCH_DBRX:
|
|
|
|
|
+ {
|
|
|
|
|
+ result = llm.build_dbrx();
|
|
|
|
|
+ } break;
|
|
|
default:
|
|
default:
|
|
|
GGML_ASSERT(false);
|
|
GGML_ASSERT(false);
|
|
|
}
|
|
}
|
|
@@ -14638,6 +14786,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|
|
// the pairs of head values are offset by n_rot/2
|
|
// the pairs of head values are offset by n_rot/2
|
|
|
case LLM_ARCH_FALCON:
|
|
case LLM_ARCH_FALCON:
|
|
|
case LLM_ARCH_GROK:
|
|
case LLM_ARCH_GROK:
|
|
|
|
|
+ case LLM_ARCH_DBRX:
|
|
|
case LLM_ARCH_PERSIMMON:
|
|
case LLM_ARCH_PERSIMMON:
|
|
|
case LLM_ARCH_BERT:
|
|
case LLM_ARCH_BERT:
|
|
|
case LLM_ARCH_NOMIC_BERT:
|
|
case LLM_ARCH_NOMIC_BERT:
|