@@ -254,6 +254,11 @@ struct llama_server_context {
             n_past += n_eval;
         }
 
+        if (params.n_predict == 0) {
+            has_next_token = false;
+            return llama_token_eos();
+        }
+
         // out of user input, sample next token
         const float temp = params.temp;
         const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@@ -419,6 +424,19 @@ struct llama_server_context {
         return token_text;
     }
+
+    std::vector<float> getEmbedding() {
+        static const int n_embd = llama_n_embd(ctx);
+        if (!params.embedding) {
+            LOG_WARNING("embedding disabled", {
+                { "params.embedding", params.embedding },
+            });
+            return std::vector<float>(n_embd, 0.0f);
+        }
+        const float * data = llama_get_embeddings(ctx);
+        std::vector<float> embedding(data, data + n_embd);
+        return embedding;
+    }
 };
 
 static void server_print_usage(const char * argv0, const gpt_params & params,
@@ -457,6 +475,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
     fprintf(stderr, "  --host                ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default (default: %d)\n", sparams.port);
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
+    fprintf(stderr, "  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     fprintf(stderr, "\n");
 }
 
@@ -603,6 +622,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             params.use_mlock = true;
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             server_print_usage(argv[0], default_params, default_sparams);
@@ -646,6 +667,12 @@ static json format_generation_settings(llama_server_context & llama) {
     };
 }
 
+static json format_embedding_response(llama_server_context & llama) {
+    return json {
+        { "embedding", llama.getEmbedding() },
+    };
+}
+
 static json format_final_response(llama_server_context & llama, const std::string & content) {
     return json {
         { "content", content },
@@ -881,12 +908,27 @@ int main(int argc, char ** argv) {
 
     svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
         const json body = json::parse(req.body);
-        const std::string content = body["content"].get<std::string>();
+        const std::string content = body.value("content", "");
         const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json");
    });
 
+    svr.Post("/embedding", [&llama](const Request & req, Response & res) {
+        const json body = json::parse(req.body);
+
+        llama.rewind();
+        llama_reset_timings(llama.ctx);
+        llama.params.prompt = body.value("content", "");
+        llama.params.n_predict = 0;
+        llama.loadPrompt();
+        llama.beginCompletion();
+        llama.doCompletion();
+
+        const json data = format_embedding_response(llama);
+        return res.set_content(data.dump(), "application/json");
+    });
+
     svr.set_logger(log_server_request);
 
     svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {
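
A quick manual check of the new endpoint (model path, port, and prompt below are illustrative placeholders): start the server with the new --embedding flag, then POST a JSON body with a "content" field to /embedding. The response is a JSON object whose "embedding" field holds n_embd floats; if the server was started without --embedding, getEmbedding() returns a zero vector and logs a warning instead.

    # start the server with embeddings enabled (model path is a placeholder)
    ./server -m models/7B/ggml-model-q4_0.bin --embedding --port 8080

    # request an embedding for an arbitrary prompt
    curl -s -X POST http://localhost:8080/embedding \
        -H "Content-Type: application/json" \
        -d '{"content": "hello world"}'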