|
|
@@ -160,17 +160,19 @@ enum llm_arch {
|
|
|
LLM_ARCH_GPTJ,
|
|
|
LLM_ARCH_GPTNEOX,
|
|
|
LLM_ARCH_MPT,
|
|
|
+ LLM_ARCH_STARCODER,
|
|
|
LLM_ARCH_UNKNOWN,
|
|
|
};
|
|
|
|
|
|
static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
|
|
- { LLM_ARCH_LLAMA, "llama" },
|
|
|
- { LLM_ARCH_FALCON, "falcon" },
|
|
|
- { LLM_ARCH_GPT2, "gpt2" },
|
|
|
- { LLM_ARCH_GPTJ, "gptj" },
|
|
|
- { LLM_ARCH_GPTNEOX, "gptneox" },
|
|
|
- { LLM_ARCH_MPT, "mpt" },
|
|
|
- { LLM_ARCH_BAICHUAN,"baichuan" },
|
|
|
+ { LLM_ARCH_LLAMA, "llama" },
|
|
|
+ { LLM_ARCH_FALCON, "falcon" },
|
|
|
+ { LLM_ARCH_GPT2, "gpt2" },
|
|
|
+ { LLM_ARCH_GPTJ, "gptj" },
|
|
|
+ { LLM_ARCH_GPTNEOX, "gptneox" },
|
|
|
+ { LLM_ARCH_MPT, "mpt" },
|
|
|
+ { LLM_ARCH_BAICHUAN, "baichuan" },
|
|
|
+ { LLM_ARCH_STARCODER, "starcoder" },
|
|
|
};
|
|
|
|
|
|
enum llm_kv {
|
|
|
@@ -376,6 +378,21 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
|
},
|
|
|
},
|
|
|
+ {
|
|
|
+ LLM_ARCH_STARCODER,
|
|
|
+ {
|
|
|
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
|
+ { LLM_TENSOR_POS_EMBD, "position_embd" },
|
|
|
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
|
|
+ { LLM_TENSOR_OUTPUT, "output" },
|
|
|
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
|
|
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
|
|
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
|
|
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
|
|
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
|
|
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
|
|
+ },
|
|
|
+ },
|
|
|
{
|
|
|
LLM_ARCH_UNKNOWN,
|
|
|
{
|
|
|
@@ -895,9 +912,11 @@ static llama_state g_state;
|
|
|
// available llama models
|
|
|
enum e_model {
|
|
|
MODEL_UNKNOWN,
|
|
|
+ MODEL_1B,
|
|
|
MODEL_3B,
|
|
|
MODEL_7B,
|
|
|
MODEL_13B,
|
|
|
+ MODEL_15B,
|
|
|
MODEL_30B,
|
|
|
MODEL_34B,
|
|
|
MODEL_40B,
|
|
|
@@ -966,13 +985,22 @@ struct llama_layer {
|
|
|
struct ggml_tensor * wo;
|
|
|
struct ggml_tensor * wqkv;
|
|
|
|
|
|
+ // attention bias
|
|
|
+ struct ggml_tensor * bo;
|
|
|
+ struct ggml_tensor * bqkv;
|
|
|
+
|
|
|
// normalization
|
|
|
struct ggml_tensor * ffn_norm;
|
|
|
+ struct ggml_tensor * ffn_norm_b;
|
|
|
|
|
|
// ff
|
|
|
struct ggml_tensor * w1; // ffn_gate
|
|
|
struct ggml_tensor * w2; // ffn_down
|
|
|
struct ggml_tensor * w3; // ffn_up
|
|
|
+
|
|
|
+ // ff bias
|
|
|
+ struct ggml_tensor * b2; // ffn_down
|
|
|
+ struct ggml_tensor * b3; // ffn_up
|
|
|
};
|
|
|
|
|
|
struct llama_kv_cache {
|
|
|
@@ -1050,6 +1078,7 @@ struct llama_model {
|
|
|
llama_vocab vocab;
|
|
|
|
|
|
struct ggml_tensor * tok_embeddings;
|
|
|
+ struct ggml_tensor * pos_embeddings;
|
|
|
|
|
|
struct ggml_tensor * output_norm;
|
|
|
struct ggml_tensor * output_norm_b;
|
|
|
@@ -1593,9 +1622,11 @@ std::string llama_model_ftype_name(enum llama_ftype ftype) {
|
|
|
|
|
|
static const char * llama_model_type_name(e_model type) {
|
|
|
switch (type) {
|
|
|
+ case MODEL_1B: return "1B";
|
|
|
case MODEL_3B: return "3B";
|
|
|
case MODEL_7B: return "7B";
|
|
|
case MODEL_13B: return "13B";
|
|
|
+ case MODEL_15B: return "15B";
|
|
|
case MODEL_30B: return "30B";
|
|
|
case MODEL_34B: return "34B";
|
|
|
case MODEL_40B: return "40B";
|
|
|
@@ -1713,6 +1744,17 @@ static void llm_load_hparams(
|
|
|
default: model.type = e_model::MODEL_UNKNOWN;
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_STARCODER:
|
|
|
+ {
|
|
|
+ GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
|
|
|
+ switch (hparams.n_layer) {
|
|
|
+ case 24: model.type = e_model::MODEL_1B; break;
|
|
|
+ case 36: model.type = e_model::MODEL_3B; break;
|
|
|
+ case 42: model.type = e_model::MODEL_7B; break;
|
|
|
+ case 40: model.type = e_model::MODEL_15B; break;
|
|
|
+ default: model.type = e_model::MODEL_UNKNOWN;
|
|
|
+ }
|
|
|
+ } break;
|
|
|
default: (void)0;
|
|
|
};
|
|
|
|
|
|
@@ -2166,6 +2208,85 @@ static void llm_load_tensors(
|
|
|
}
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_STARCODER:
|
|
|
+ {
|
|
|
+ model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
|
|
+ model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
|
|
|
+
|
|
|
+ // output
|
|
|
+ {
|
|
|
+ ggml_backend backend_norm;
|
|
|
+ ggml_backend backend_output;
|
|
|
+
|
|
|
+ if (n_gpu_layers > int(n_layer)) {
|
|
|
+ // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
|
|
|
+ // on Windows however this is detrimental unless everything is on the GPU
|
|
|
+#ifndef _WIN32
|
|
|
+ backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
|
|
+#else
|
|
|
+ backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
|
|
+#endif // _WIN32
|
|
|
+
|
|
|
+ backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
|
|
|
+ } else {
|
|
|
+ backend_norm = GGML_BACKEND_CPU;
|
|
|
+ backend_output = GGML_BACKEND_CPU;
|
|
|
+ }
|
|
|
+
|
|
|
+ model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
|
|
+ model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
|
|
|
+ model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
|
|
+
|
|
|
+ if (backend_norm == GGML_BACKEND_GPU) {
|
|
|
+ vram_weights += ggml_nbytes(model.output_norm);
|
|
|
+ vram_weights += ggml_nbytes(model.output_norm_b);
|
|
|
+ }
|
|
|
+ if (backend_output == GGML_BACKEND_GPU_SPLIT) {
|
|
|
+ vram_weights += ggml_nbytes(model.output);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const uint32_t n_ff = hparams.n_ff;
|
|
|
+
|
|
|
+ const int i_gpu_start = n_layer - n_gpu_layers;
|
|
|
+
|
|
|
+ model.layers.resize(n_layer);
|
|
|
+
|
|
|
+ for (uint32_t i = 0; i < n_layer; ++i) {
|
|
|
+ const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
|
|
|
+ const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
|
|
|
+
|
|
|
+ auto & layer = model.layers[i];
|
|
|
+
|
|
|
+ layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
|
|
+ layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
|
|
|
+
|
|
|
+ layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
|
|
|
+ layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
|
|
|
+
|
|
|
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
|
|
+ layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
|
|
|
+
|
|
|
+ layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
|
|
|
+ layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
|
|
|
+
|
|
|
+ layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
|
|
|
+ layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
|
|
|
+
|
|
|
+ layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
|
|
+ layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
|
|
|
+
|
|
|
+ if (backend == GGML_BACKEND_GPU) {
|
|
|
+ vram_weights +=
|
|
|
+ ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
|
|
|
+ ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) +
|
|
|
+ ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) +
|
|
|
+ ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) +
|
|
|
+ ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2) +
|
|
|
+ ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } break;
|
|
|
default:
|
|
|
throw std::runtime_error("unknown architecture");
|
|
|
};
|
|
|
@@ -3305,6 +3426,235 @@ static struct ggml_cgraph * llm_build_falcon(
|
|
|
return gf;
|
|
|
}
|
|
|
|
|
|
+static struct ggml_cgraph * llm_build_starcoder(
|
|
|
+ llama_context & lctx,
|
|
|
+ const llama_token * tokens,
|
|
|
+ const float * embd,
|
|
|
+ int n_tokens,
|
|
|
+ int n_past) {
|
|
|
+
|
|
|
+ GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
|
|
|
+
|
|
|
+ const int N = n_tokens;
|
|
|
+
|
|
|
+ const auto & model = lctx.model;
|
|
|
+ const auto & hparams = model.hparams;
|
|
|
+
|
|
|
+ const auto & kv_self = lctx.kv_self;
|
|
|
+
|
|
|
+ GGML_ASSERT(!!kv_self.ctx);
|
|
|
+
|
|
|
+ const int64_t n_embd = hparams.n_embd;
|
|
|
+ const int64_t n_layer = hparams.n_layer;
|
|
|
+ const int64_t n_ctx = hparams.n_ctx;
|
|
|
+ const int64_t n_head = hparams.n_head;
|
|
|
+ const int64_t n_head_kv = hparams.n_head_kv;
|
|
|
+ const int64_t n_embd_head = hparams.n_embd_head();
|
|
|
+ const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
|
|
+
|
|
|
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
|
|
|
+
|
|
|
+ const float norm_eps = hparams.f_norm_eps;
|
|
|
+
|
|
|
+ auto & buf_compute = lctx.buf_compute;
|
|
|
+
|
|
|
+ struct ggml_init_params params = {
|
|
|
+ /*.mem_size =*/ buf_compute.size,
|
|
|
+ /*.mem_buffer =*/ buf_compute.data,
|
|
|
+ /*.no_alloc =*/ false,
|
|
|
+ };
|
|
|
+
|
|
|
+ params.no_alloc = true;
|
|
|
+
|
|
|
+ struct ggml_context * ctx0 = ggml_init(params);
|
|
|
+
|
|
|
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
|
|
|
+
|
|
|
+ struct ggml_tensor * cur;
|
|
|
+ struct ggml_tensor * token;
|
|
|
+ struct ggml_tensor * position;
|
|
|
+ struct ggml_tensor * inpL;
|
|
|
+
|
|
|
+ if (tokens) {
|
|
|
+ struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
|
|
+
|
|
|
+ ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
|
|
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
|
|
|
+ memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
|
|
+ }
|
|
|
+ ggml_set_name(inp_tokens, "inp_tokens");
|
|
|
+
|
|
|
+ token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
|
|
+ } else {
|
|
|
+#ifdef GGML_USE_MPI
|
|
|
+ GGML_ASSERT(false && "not implemented");
|
|
|
+#endif
|
|
|
+
|
|
|
+ token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
|
|
+
|
|
|
+ ggml_allocr_alloc(lctx.alloc, token);
|
|
|
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
|
|
|
+ memcpy(token->data, embd, N * n_embd * ggml_element_size(inpL));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ {
|
|
|
+ // Compute position embeddings.
|
|
|
+ struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
|
|
+ ggml_allocr_alloc(lctx.alloc, inp_positions);
|
|
|
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
|
|
|
+ for (int i = 0; i < N; ++i) {
|
|
|
+ ((int32_t *) inp_positions->data)[i] = n_past + i;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ggml_set_name(inp_positions, "inp_positions");
|
|
|
+
|
|
|
+ position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions);
|
|
|
+ }
|
|
|
+
|
|
|
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
|
|
+ ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
|
|
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
|
|
|
+ ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
|
|
+ }
|
|
|
+ ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
|
|
+
|
|
|
+ inpL = ggml_add(ctx0, token, position);
|
|
|
+ ggml_set_name(inpL, "inpL");
|
|
|
+
|
|
|
+ for (int il = 0; il < n_layer; ++il) {
|
|
|
+ {
|
|
|
+ // Norm
|
|
|
+ cur = ggml_norm(ctx0, inpL, norm_eps);
|
|
|
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
|
|
|
+ }
|
|
|
+
|
|
|
+ {
|
|
|
+ // Self Attention
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
|
|
|
+
|
|
|
+ struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
|
|
|
+ struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*n_embd);
|
|
|
+ struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
|
|
|
+
|
|
|
+ struct ggml_tensor * Qcur = tmpq;
|
|
|
+ struct ggml_tensor * Kcur = tmpk;
|
|
|
+
|
|
|
+ {
|
|
|
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
|
|
|
+ ggml_set_name(Vcur, "Vcur");
|
|
|
+
|
|
|
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
|
|
|
+ ggml_set_name(k, "k");
|
|
|
+
|
|
|
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
|
|
|
+ ( n_ctx)*ggml_element_size(kv_self.v),
|
|
|
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
|
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
|
|
+ }
|
|
|
+
|
|
|
+ struct ggml_tensor * Q =
|
|
|
+ ggml_permute(ctx0,
|
|
|
+ ggml_cpy(ctx0,
|
|
|
+ Qcur,
|
|
|
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, N)),
|
|
|
+ 0, 2, 1, 3);
|
|
|
+ ggml_set_name(Q, "Q");
|
|
|
+
|
|
|
+ struct ggml_tensor * K =
|
|
|
+ ggml_view_3d(ctx0, kv_self.k,
|
|
|
+ n_embd_head, n_past + N, n_head_kv,
|
|
|
+ ggml_element_size(kv_self.k)*n_embd_gqa,
|
|
|
+ ggml_element_size(kv_self.k)*n_embd_head,
|
|
|
+ ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
|
|
|
+ ggml_set_name(K, "K");
|
|
|
+
|
|
|
+ // K * Q
|
|
|
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
|
+ ggml_set_name(KQ, "KQ");
|
|
|
+
|
|
|
+ // KQ_scaled = KQ / sqrt(n_embd_head)
|
|
|
+ // KQ_scaled shape [n_past + N, N, n_head, 1]
|
|
|
+ struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
|
|
+ ggml_set_name(KQ_scaled, "KQ_scaled");
|
|
|
+
|
|
|
+ // KQ_masked = mask_past(KQ_scaled)
|
|
|
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
|
|
+ ggml_set_name(KQ_masked, "KQ_masked");
|
|
|
+
|
|
|
+ // KQ = soft_max(KQ_masked)
|
|
|
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
|
|
+ ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
|
|
+
|
|
|
+ // split cached V into n_head heads
|
|
|
+ struct ggml_tensor * V =
|
|
|
+ ggml_view_3d(ctx0, kv_self.v,
|
|
|
+ n_past + N, n_embd_head, n_head_kv,
|
|
|
+ ggml_element_size(kv_self.v)*n_ctx,
|
|
|
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
|
|
|
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
|
|
|
+ ggml_set_name(V, "V");
|
|
|
+
|
|
|
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
|
|
+ ggml_set_name(KQV, "KQV");
|
|
|
+
|
|
|
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
|
|
|
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
|
+ ggml_set_name(KQV_merged, "KQV_merged");
|
|
|
+
|
|
|
+ // cur = KQV_merged.contiguous().view(n_embd, N)
|
|
|
+ cur = ggml_cpy(ctx0,
|
|
|
+ KQV_merged,
|
|
|
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
|
|
+ ggml_set_name(cur, "KQV_merged_contiguous");
|
|
|
+ }
|
|
|
+
|
|
|
+ // Projection
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
|
|
|
+
|
|
|
+ // Add the input
|
|
|
+ cur = ggml_add(ctx0, cur, inpL);
|
|
|
+
|
|
|
+ struct ggml_tensor * inpFF = cur;
|
|
|
+
|
|
|
+ // FF
|
|
|
+ {
|
|
|
+ // Norm
|
|
|
+ {
|
|
|
+ cur = ggml_norm(ctx0, inpFF, norm_eps);
|
|
|
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
|
|
|
+ }
|
|
|
+
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
|
|
|
+
|
|
|
+ // GELU activation
|
|
|
+ cur = ggml_gelu(ctx0, cur);
|
|
|
+
|
|
|
+ // Projection
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
|
|
|
+ }
|
|
|
+
|
|
|
+ inpL = ggml_add(ctx0, cur, inpFF);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Output Norm
|
|
|
+ {
|
|
|
+ cur = ggml_norm(ctx0, inpL, norm_eps);
|
|
|
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
|
|
|
+ }
|
|
|
+ ggml_set_name(cur, "result_norm");
|
|
|
+
|
|
|
+ cur = ggml_mul_mat(ctx0, model.output, cur);
|
|
|
+ ggml_set_name(cur, "result_output");
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, cur);
|
|
|
+ ggml_free(ctx0);
|
|
|
+
|
|
|
+ return gf;
|
|
|
+}
|
|
|
+
|
|
|
static struct ggml_cgraph * llama_build_graph(
|
|
|
llama_context & lctx,
|
|
|
const llama_token * tokens,
|
|
|
@@ -3328,6 +3678,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|
|
{
|
|
|
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
|
|
|
} break;
|
|
|
+ case LLM_ARCH_STARCODER:
|
|
|
+ {
|
|
|
+ result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past);
|
|
|
+ } break;
|
|
|
default:
|
|
|
GGML_ASSERT(false);
|
|
|
};
|