@@ -184,6 +184,7 @@ enum llm_arch {
     LLM_ARCH_OLMOE,
     LLM_ARCH_OPENELM,
     LLM_ARCH_ARCTIC,
+    LLM_ARCH_DEEPSEEK,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
     LLM_ARCH_BITNET,
@@ -239,6 +240,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_OLMOE, "olmoe" },
     { LLM_ARCH_OPENELM, "openelm" },
     { LLM_ARCH_ARCTIC, "arctic" },
+    { LLM_ARCH_DEEPSEEK, "deepseek" },
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM, "chatglm" },
     { LLM_ARCH_BITNET, "bitnet" },
@@ -1309,6 +1311,33 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_DEEPSEEK,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_DEEPSEEK2,
         {
@@ -1600,6 +1629,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_EXAONE_3,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
+    LLM_CHAT_TEMPLATE_GIGACHAT,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
@@ -1631,6 +1661,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
+    { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
 };
 
 static llm_arch llm_arch_from_string(const std::string & name) {
@@ -6094,6 +6125,19 @@ static void llm_load_hparams(
                     model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+
+                switch (hparams.n_layer) {
+                    case 28: model.type = e_model::MODEL_20B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DEEPSEEK2:
             {
                 bool is_lite = (hparams.n_layer == 27);
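
As a reference for the hunk above, this is a sketch of the GGUF metadata keys those ml.get_key() calls resolve to for the "deepseek" architecture, assuming the usual "%s.<key>" patterns behind the LLM_KV_* constants; the literal key strings below are inferred for illustration, not taken from this diff:

    // GGUF metadata keys read by the LLM_ARCH_DEEPSEEK branch above (inferred names, illustration only)
    static const char * deepseek_hparam_keys[] = {
        "deepseek.attention.layer_norm_rms_epsilon", // LLM_KV_ATTENTION_LAYERNORM_RMS_EPS -> hparams.f_norm_rms_eps
        "deepseek.leading_dense_block_count",        // LLM_KV_LEADING_DENSE_BLOCK_COUNT   -> hparams.n_layer_dense_lead
        "deepseek.expert_feed_forward_length",       // LLM_KV_EXPERT_FEED_FORWARD_LENGTH  -> hparams.n_ff_exp
        "deepseek.expert_shared_count",              // LLM_KV_EXPERT_SHARED_COUNT         -> hparams.n_expert_shared
        "deepseek.expert_weights_scale",             // LLM_KV_EXPERT_WEIGHTS_SCALE        -> hparams.expert_weights_scale
    };
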
@@ -6440,6 +6484,7 @@ static void llm_load_vocab(
                 tokenizer_pre == "phi-2" ||
                 tokenizer_pre == "jina-es" ||
                 tokenizer_pre == "jina-de" ||
+                tokenizer_pre == "gigachat" ||
                 tokenizer_pre == "jina-v1-en" ||
                 tokenizer_pre == "jina-v2-es" ||
                 tokenizer_pre == "jina-v2-de" ||
@@ -7091,6 +7136,13 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
 
+    if (model.arch == LLM_ARCH_DEEPSEEK) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+    }
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
@@ -8865,6 +8917,55 @@ static bool llm_load_tensors(
                         layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
                     }
                 } break;
+            case LLM_ARCH_DEEPSEEK:
+                {
+
+                    const int64_t n_ff_exp = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                            layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DEEPSEEK2:
                 {
                     const bool is_lite = (hparams.n_layer == 27);
@@ -15219,6 +15320,161 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_deepseek() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up, NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                    llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, hparams.expert_weights_scale,
+                        cb, il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_shexp, NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_deepseek2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -16906,6 +17162,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_arctic();
             } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                result = llm.build_deepseek();
+            } break;
         case LLM_ARCH_DEEPSEEK2:
             {
                 result = llm.build_deepseek2();
@@ -20137,6 +20397,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_COMMAND_R:
         case LLM_ARCH_OLMO:
         case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK:
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
         case LLM_ARCH_GRANITE:
@@ -22002,6 +22263,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
+    } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
+        return LLM_CHAT_TEMPLATE_GIGACHAT;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -22325,6 +22588,32 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
+        // GigaChat template
+        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+
+        // Handle system message if present
+        if (has_system) {
+            ss << "<s>" << chat[0]->content << "<|message_sep|>";
+        } else {
+            ss << "<s>";
+        }
+
+        // Process remaining messages
+        for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "user") {
+                ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
+                   << "available functions<|role_sep|>[]<|message_sep|>";
+            } else if (role == "assistant") {
+                ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
+            }
+        }
+
+        // Add generation prompt if needed
+        if (add_ass) {
+            ss << "assistant<|role_sep|>";
+        }
     } else {
         // template not supported
         return -1;
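
For reference, a minimal sketch of the prompt the GigaChat branch above renders for a short system + user exchange with add_ass == true; the separator tokens come straight from the code in the hunk, while the message contents and the helper name are made up for illustration:

    #include <string>

    // Expected GigaChat rendering of chat = [{role: "system", content: "S"}, {role: "user", content: "U"}]
    static std::string gigachat_prompt_example() {
        std::string ss;
        ss += "<s>S<|message_sep|>";                              // system message opens the prompt
        ss += "user<|role_sep|>U<|message_sep|>";                 // user turn
        ss += "available functions<|role_sep|>[]<|message_sep|>"; // empty tool list appended after each user turn
        ss += "assistant<|role_sep|>";                            // generation prompt (add_ass == true)
        return ss;
    }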