|
|
@@ -102,6 +102,9 @@ struct llama_context {
|
|
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
|
|
std::vector<float> logits;
|
|
|
bool logits_all = false;
|
|
|
+
|
|
|
+ // input embedding (1-dimensional array: [n_embd])
|
|
|
+ std::vector<float> embedding;
|
|
|
};
|
|
|
|
|
|
struct llama_context_params llama_context_default_params() {
|
|
|
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
|
|
|
/*.f16_kv =*/ false,
|
|
|
/*.logits_all =*/ false,
|
|
|
/*.vocab_only =*/ false,
|
|
|
+ /*.embedding =*/ false,
|
|
|
};
|
|
|
|
|
|
return result;
|
|
|
@@ -592,8 +596,6 @@ static bool llama_model_load(
|
|
|
fin.close();
|
|
|
}
|
|
|
|
|
|
- lctx.logits.reserve(lctx.model.hparams.n_ctx);
|
|
|
-
|
|
|
lctx.t_load_us = ggml_time_us() - t_start_us;
|
|
|
|
|
|
return true;
|
|
|
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
|
|
|
inpL = cur;
|
|
|
}
|
|
|
|
|
|
+ // used at the end to optionally extract the embeddings
|
|
|
+ struct ggml_tensor * embeddings = NULL;
|
|
|
+
|
|
|
// norm
|
|
|
{
|
|
|
inpL = ggml_rms_norm(ctx0, inpL);
|
|
|
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
|
|
|
inpL = ggml_mul(ctx0,
|
|
|
ggml_repeat(ctx0, model.norm, inpL),
|
|
|
inpL);
|
|
|
+
|
|
|
+ embeddings = inpL;
|
|
|
}
|
|
|
|
|
|
// lm_head
|
|
|
@@ -821,15 +828,26 @@ static bool llama_eval_internal(
|
|
|
//embd_w.resize(n_vocab*N);
|
|
|
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
|
|
|
|
|
- auto & logits_out = lctx.logits;
|
|
|
+ // extract logits
|
|
|
+ {
|
|
|
+ auto & logits_out = lctx.logits;
|
|
|
+
|
|
|
+ if (lctx.logits_all) {
|
|
|
+ logits_out.resize(n_vocab * N);
|
|
|
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
|
|
+ } else {
|
|
|
+ // return result for just the last token
|
|
|
+ logits_out.resize(n_vocab);
|
|
|
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // extract embeddings
|
|
|
+ if (lctx.embedding.size()) {
|
|
|
+ auto & embedding_out = lctx.embedding;
|
|
|
|
|
|
- if (lctx.logits_all) {
|
|
|
- logits_out.resize(n_vocab * N);
|
|
|
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
|
|
- } else {
|
|
|
- // return result for just the last token
|
|
|
- logits_out.resize(n_vocab);
|
|
|
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
|
|
+ embedding_out.resize(n_embd);
|
|
|
+ memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
|
|
|
}
|
|
|
|
|
|
if (mem_per_token == 0) {
|
|
|
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
|
|
|
return nullptr;
|
|
|
}
|
|
|
|
|
|
+ // reserve memory for context buffers
|
|
|
+ {
|
|
|
+ const auto & hparams = ctx->model.hparams;
|
|
|
+ if (params.logits_all) {
|
|
|
+ ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
|
|
+ } else {
|
|
|
+ ctx->logits.reserve(hparams.n_ctx);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (params.embedding){
|
|
|
+ ctx->embedding.reserve(hparams.n_embd);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return ctx;
|
|
|
}
|
|
|
|
|
|
@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
|
|
|
return ctx->logits.data();
|
|
|
}
|
|
|
|
|
|
+float * llama_get_embeddings(struct llama_context * ctx) {
|
|
|
+ return ctx->embedding.data();
|
|
|
+}
|
|
|
+
|
|
|
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
|
|
|
if (token >= llama_n_vocab(ctx)) {
|
|
|
return nullptr;
|