@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
 
@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
 
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
 
         res->index = slot.task->index;
-        res->content = slot.generated_text;
-        res->tokens = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
 
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
        rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
         res->status = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
+            };
+
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+
+                server_response_reader & rd = res_this->rd;
+
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
                 } else {
-                    output = format_oai_sse(res_json);
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
                 }
-            }
 
-            // has next data, continue
-            return true;
+                // has next data, continue
+                return true;
+
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+
+                // terminate on exception
+                return false;
+            }
         };
     }
 