Kaynağa Gözat

server: move msg diffs tracking to HTTP thread (#17740)

* server: move msg diffs tracking to HTTP thread

* wip

* tool call tests ok

* minor : style

* cont : fix

* move states to server_response_reader

* add safe-guard

* fix

* fix 2

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen 1 ay önce
ebeveyn
işleme
c4c10bfb86

+ 85 - 86
tools/server/server-context.cpp

@@ -101,8 +101,6 @@ struct server_slot {
     std::string  generated_text;
     llama_tokens generated_tokens;
 
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
 
@@ -183,13 +178,10 @@ struct server_slot {
         stop           = STOP_TYPE_NONE;
         stopping_word  = "";
         n_sent_text    = 0;
-        chat_format    = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
 
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens  = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded           = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
 
         res->index           = slot.task->index;
-        res->content         = slot.generated_text;
-        res->tokens          = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content     = "";
+            res->tokens      = llama_tokens{};
+        } else {
+            res->content     = std::move(slot.generated_text);
+            res->tokens      = std::move(slot.generated_tokens);
+        }
         res->timings         = slot.get_timings();
         res->prompt          = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type          = slot.task->params.res_type;
         res->oaicompat_model   = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg     = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
 
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type          = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model   = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
         res->status = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
+            };
+
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+
+                server_response_reader & rd = res_this->rd;
+
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
                 } else {
-                    output = format_oai_sse(res_json);
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
                 }
-            }
 
-            // has next data, continue
-            return true;
+                // has next data, continue
+                return true;
+
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+
+                // terminate on exception
+                return false;
+            }
         };
     }
 

+ 10 - 0
tools/server/server-queue.cpp

@@ -271,6 +271,10 @@ void server_response::terminate() {
 // server_response_reader
 //
 
+void server_response_reader::set_states(std::vector<task_result_state> && states) {
+    this->states = std::move(states);
+}
+
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
     id_tasks = server_task::get_list_id(tasks);
     queue_results.add_waiting_tasks(tasks);
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
                 SRV_DBG("%s", "received error result, stopping further processing\n");
                 return result;
             }
+            if (!states.empty()) {
+                // update the generation state if needed
+                size_t idx = result->get_index();
+                GGML_ASSERT(idx < states.size());
+                result->update(states[idx]);
+            }
             if (result->is_stop()) {
                 received_count++;
             }

+ 9 - 0
tools/server/server-queue.h

@@ -7,6 +7,8 @@
 #include <mutex>
 #include <unordered_set>
 
+// struct for managing server tasks
+// in most cases, use server_response_reader to post new tasks and retrieve results
 struct server_queue {
 private:
     int id = 0;
@@ -67,6 +69,8 @@ private:
     void cleanup_pending_task(int id_target);
 };
 
+// struct for managing server responses
+// in most cases, use server_response_reader to retrieve results
 struct server_response {
 private:
     bool running = true;
@@ -120,6 +124,10 @@ struct server_response_reader {
     bool cancelled = false;
     int polling_interval_seconds;
 
+    // tracking generation state and partial tool calls
+    // only used by streaming completions
+    std::vector<task_result_state> states;
+
     // should_stop function will be called each polling_interval_seconds
     server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
         : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
@@ -127,6 +135,7 @@ struct server_response_reader {
         stop();
     }
 
+    void set_states(std::vector<task_result_state> && states);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;
 

+ 24 - 3
tools/server/server-task.cpp

@@ -565,6 +565,7 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
 // server_task_result_cmpl_final
 //
 json server_task_result_cmpl_final::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();
@@ -582,8 +583,8 @@ json server_task_result_cmpl_final::to_json() {
 json server_task_result_cmpl_final::to_json_non_oaicompat() {
     json res = json {
         {"index",               index},
-        {"content",             stream ? "" : content}, // in stream mode, content is already in last partial chunk
-        {"tokens",              stream ? llama_tokens {} : tokens},
+        {"content",             content},
+        {"tokens",              tokens},
         {"id_slot",             id_slot},
         {"stop",                true},
         {"model",               oaicompat_model},
@@ -619,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
     json res = json {
         {"choices",            json::array({
             json{
-                {"text",          stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                {"text",          content},
                 {"index",         index},
                 {"logprobs",      logprobs},
                 {"finish_reason", finish_reason},
@@ -700,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
     return res;
 }
 
+common_chat_msg task_result_state::update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs) {
+    generated_text += text_added;
+    auto msg_prv_copy = chat_msg;
+    SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+    auto new_msg = common_chat_parse(
+        generated_text,
+        is_partial,
+        oaicompat_chat_syntax);
+    if (!new_msg.empty()) {
+        new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
+        chat_msg = new_msg;
+        diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
+    }
+    return chat_msg;
+}
+
 json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
     std::time_t t = std::time(0);
     std::string finish_reason = "length";
@@ -956,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
 // server_task_result_cmpl_partial
 //
 json server_task_result_cmpl_partial::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();

+ 37 - 3
tools/server/server-task.h

@@ -161,6 +161,25 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task_result {
     int id           = -1;
     int id_slot      = -1;
@@ -175,6 +194,9 @@ struct server_task_result {
     virtual int get_index() {
         return -1;
     }
+    virtual void update(task_result_state &) {
+        // only used by server_task_result_cmpl_*
+    }
     virtual json to_json() = 0;
     virtual ~server_task_result() = default;
 };
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string        oaicompat_model;
     std::string        oaicompat_cmpl_id;
-    common_chat_msg    oaicompat_msg;
+    common_chat_msg    oaicompat_msg; // to be populated by update()
 
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
 
     virtual int get_index() override {
         return index;
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
 
     virtual json to_json() override;
 
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
+    }
+
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string        oaicompat_model;
     std::string        oaicompat_cmpl_id;
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
 
     virtual int get_index() override {
         return index;
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
 
     virtual json to_json() override;
 
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        state.update_chat_msg(content, true, oaicompat_msg_diffs);
+    }
+
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();