@@ -2589,6 +2589,10 @@ struct server_context_impl {
     int get_slot_n_ctx() {
         return slots.back().n_ctx;
     }
+
+    server_response_reader get_response_reader() {
+        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+    }
 };
 
 //
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
-std::pair<server_queue &, server_response &> server_context::get_queues() {
-    return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+    return impl->get_response_reader();
 }
 
 
@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
 struct server_res_generator : server_http_res {
     server_response_reader rd;
     server_res_generator(server_context_impl & ctx_server)
-        : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
     void ok(const json & response_data) {
         status = 200;
         data = safe_json_to_str(response_data);
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
-        // tracking generation state and partial tool calls
-        std::vector<task_result_state> states;
-
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
-        states.reserve(inputs.size());
         int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
-            states.push_back(task.params.oaicompat_chat_syntax);
 
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
@@ -2707,7 +2706,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                         task.id,
                         ctx_server.queue_tasks.get_new_id(),
                         idx++);
-                    states.push_back(child.params.oaicompat_chat_syntax);
                     tasks.push_back(std::move(child));
                 }
             }
@@ -2715,7 +2713,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             tasks.push_back(std::move(task));
         }
 
-        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
 
         // create and queue the task
         json responses = json::array();
-        server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+        server_response_reader rd = ctx_server.get_response_reader();
         {
             std::vector<server_task> tasks;
             tasks.reserve(documents.size());
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {