Просмотр исходного кода

server: (refactor) implement generator-based API for task results (#17174)

* server: (refactor) implement generator-based API for task results

* improve

* moving some code

* fix "Response ended prematurely"

* add sink.done before return false

* rm redundant check

* rm unused var

* rename generator --> reader
Xuan-Son Nguyen 2 месяцев назад
Родитель
Сommit
00c94083b3
2 измененных файлов с 232 добавлено и 193 удалено
  1. 212 187
      tools/server/server.cpp
  2. 20 6
      tools/server/utils.hpp

+ 212 - 187
tools/server/server.cpp

@@ -684,7 +684,7 @@ struct server_task_result {
     }
     virtual bool is_stop() {
         // only used by server_task_result_cmpl_*
-        return false;
+        return true;
     }
     virtual int get_index() {
         return -1;
@@ -3238,105 +3238,6 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    //
-    // Functions to create new task(s) and receive result(s)
-    //
-
-    void cancel_tasks(const std::unordered_set<int> & id_tasks) {
-        std::vector<server_task> cancel_tasks;
-        cancel_tasks.reserve(id_tasks.size());
-        for (const auto & id_task : id_tasks) {
-            SRV_WRN("cancel task, id_task = %d\n", id_task);
-
-            server_task task(SERVER_TASK_TYPE_CANCEL);
-            task.id_target = id_task;
-            queue_results.remove_waiting_task_id(id_task);
-            cancel_tasks.push_back(std::move(task));
-        }
-        // push to beginning of the queue, so it has highest priority
-        queue_tasks.post(std::move(cancel_tasks), true);
-    }
-
-    // receive the results from task(s)
-    void receive_multi_results(
-            const std::unordered_set<int> & id_tasks,
-            const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
-            const std::function<void(json)> & error_handler,
-            const std::function<bool()> & is_connection_closed) {
-        std::vector<server_task_result_ptr> results(id_tasks.size());
-        for (int i = 0; i < (int)id_tasks.size(); i++) {
-            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
-            if (is_connection_closed()) {
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            if (result == nullptr) {
-                i--; // retry
-                continue;
-            }
-
-            if (result->is_error()) {
-                error_handler(result->to_json());
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
-            );
-            const size_t idx = result->get_index();
-            GGML_ASSERT(idx < results.size() && "index out of range");
-            results[idx] = std::move(result);
-        }
-        result_handler(results);
-    }
-
-    // receive the results from task(s), in stream mode
-    void receive_cmpl_results_stream(
-            const std::unordered_set<int> & id_tasks,
-            const std::function<bool(server_task_result_ptr&)> & result_handler,
-            const std::function<void(json)> & error_handler,
-            const std::function<bool()> & is_connection_closed) {
-        size_t n_finished = 0;
-        while (true) {
-            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
-            if (is_connection_closed()) {
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            if (result == nullptr) {
-                continue; // retry
-            }
-
-            if (result->is_error()) {
-                error_handler(result->to_json());
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-            );
-            if (!result_handler(result)) {
-                cancel_tasks(id_tasks);
-                break;
-            }
-
-            if (result->is_stop()) {
-                if (++n_finished == id_tasks.size()) {
-                    break;
-                }
-            }
-        }
-    }
-
     //
     // Functions to process the task
     //
@@ -4418,6 +4319,104 @@ struct server_context {
     }
 };
 
+// generator-like API for server responses, support pooling connection state and aggregating results
+struct server_response_reader {
+    std::unordered_set<int> id_tasks;
+    server_context & ctx_server;
+    size_t received_count = 0;
+    bool cancelled = false;
+
+    server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {}
+    ~server_response_reader() {
+        stop();
+    }
+
+    void post_tasks(std::vector<server_task> && tasks) {
+        id_tasks = server_task::get_list_id(tasks);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(std::move(tasks));
+    }
+
+    bool has_next() {
+        return !cancelled && received_count < id_tasks.size();
+    }
+
+    // return nullptr if should_stop() is true before receiving a result
+    // note: if one error is received, it will stop further processing and return error result
+    server_task_result_ptr next(const std::function<bool()> & should_stop) {
+        while (true) {
+            server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+            if (result == nullptr) {
+                // timeout, check stop condition
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
+                    return nullptr;
+                }
+            } else {
+                if (result->is_error()) {
+                    stop(); // cancel remaining tasks
+                    SRV_DBG("%s", "received error result, stopping further processing\n");
+                    return result;
+                }
+                if (result->is_stop()) {
+                    received_count++;
+                }
+                return result;
+            }
+        }
+
+        // should not reach here
+    }
+
+    struct batch_response {
+        bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
+        std::vector<server_task_result_ptr> results;
+        server_task_result_ptr error; // nullptr if no error
+    };
+
+    batch_response wait_for_all(const std::function<bool()> & should_stop) {
+        batch_response batch_res;
+        batch_res.results.resize(id_tasks.size());
+        while (has_next()) {
+            auto res = next(should_stop);
+            if (res == nullptr) {
+                batch_res.is_terminated = true;
+                return batch_res;
+            }
+            if (res->is_error()) {
+                batch_res.error = std::move(res);
+                return batch_res;
+            }
+            const size_t idx = res->get_index();
+            GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
+            GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
+            batch_res.results[idx] = std::move(res);
+        }
+        return batch_res;
+    }
+
+    void stop() {
+        ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
+        if (has_next() && !cancelled) {
+            // if tasks is not finished yet, cancel them
+            cancelled = true;
+            std::vector<server_task> cancel_tasks;
+            cancel_tasks.reserve(id_tasks.size());
+            for (const auto & id_task : id_tasks) {
+                SRV_WRN("cancel task, id_task = %d\n", id_task);
+                server_task task(SERVER_TASK_TYPE_CANCEL);
+                task.id_target = id_task;
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
+                cancel_tasks.push_back(std::move(task));
+            }
+            // push to beginning of the queue, so it has highest priority
+            ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
+        } else {
+            SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
+        }
+    }
+};
+
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health") {
@@ -5000,7 +4999,10 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
         auto completion_id = gen_chatcmplid();
-        std::unordered_set<int> task_ids;
+        // need to store the reader as a pointer, so that it won't be destroyed when the handle returns
+        // use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
+        const auto rd = std::make_shared<server_response_reader>(ctx_server);
+
         try {
             std::vector<server_task> tasks;
 
@@ -5018,17 +5020,8 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
-                auto n_prompt_tokens = inputs[i].size();
-                if (n_prompt_tokens >= n_ctx_slot) {
-                    json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
-                    error_data["n_prompt_tokens"] = n_prompt_tokens;
-                    error_data["n_ctx"] = n_ctx_slot;
-                    res_error(res, error_data);
-                    return;
-                }
                 server_task task = server_task(type);
 
                 task.id    = ctx_server.queue_tasks.get_new_id();
@@ -5049,9 +5042,7 @@ int main(int argc, char ** argv) {
                 tasks.push_back(std::move(task));
             }
 
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            rd->post_tasks(std::move(tasks));
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
             return;
@@ -5060,54 +5051,95 @@ int main(int argc, char ** argv) {
         bool stream = json_value(data, "stream", false);
 
         if (!stream) {
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-                if (results.size() == 1) {
-                    // single result
-                    res_ok(res, results[0]->to_json());
-                } else {
-                    // multiple results (multitask)
-                    json arr = json::array();
-                    for (auto & res : results) {
-                        arr.push_back(res->to_json());
-                    }
-                    res_ok(res, arr);
+            // non-stream, wait for the results
+            auto all_results = rd->wait_for_all(is_connection_closed);
+            if (all_results.is_terminated) {
+                return; // connection is closed
+            } else if (all_results.error) {
+                res_error(res, all_results.error->to_json());
+                return;
+            } else {
+                json arr = json::array();
+                for (auto & res : all_results.results) {
+                    GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
+                    arr.push_back(res->to_json());
                 }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            }, is_connection_closed);
+                // if single request, return single object instead of array
+                res_ok(res, arr.size() == 1 ? arr[0] : arr);
+            }
 
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
-                    json res_json = result->to_json();
-                    if (res_json.is_array()) {
-                        for (const auto & res : res_json) {
-                            if (!server_sent_event(sink, res)) {
-                                // sending failed (HTTP connection closed), cancel the generation
-                                return false;
-                            }
-                        }
-                        return true;
-                    } else {
-                        return server_sent_event(sink, res_json);
+            // in streaming mode, the first error must be treated as non-stream response
+            // this is to match the OAI API behavior
+            // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
+            server_task_result_ptr first_result = rd->next(is_connection_closed);
+            if (first_result == nullptr) {
+                return; // connection is closed
+            } else if (first_result->is_error()) {
+                res_error(res, first_result->to_json());
+                return;
+            } else {
+                GGML_ASSERT(
+                    dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
+                    || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
+                );
+            }
+
+            // next responses are streamed
+            json first_result_json = first_result->to_json();
+            const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool {
+                // flush the first result as it's not an error
+                if (!first_result_json.empty()) {
+                    if (!server_sent_event(sink, first_result_json)) {
+                        sink.done();
+                        return false; // sending failed, go to on_complete()
                     }
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, json{{"error", error_data}});
-                }, [&sink]() {
-                    // note: do not use req.is_connection_closed here because req is already destroyed
-                    return !sink.is_writable();
-                });
-                if (oaicompat != OAICOMPAT_TYPE_NONE) {
-                    static const std::string ev_done = "data: [DONE]\n\n";
-                    sink.write(ev_done.data(), ev_done.size());
+                    first_result_json.clear(); // mark as sent
                 }
-                sink.done();
-                return false;
+
+                // receive subsequent results
+                auto result = rd->next([&sink]{ return !sink.is_writable(); });
+                if (result == nullptr) {
+                    sink.done();
+                    return false; // connection is closed, go to on_complete()
+                }
+
+                // send the results
+                json res_json = result->to_json();
+                bool ok = false;
+                if (result->is_error()) {
+                    ok = server_sent_event(sink, json {{ "error", result->to_json() }});
+                    sink.done();
+                    return false; // go to on_complete()
+                } else {
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    ok = server_sent_event(sink, res_json);
+                }
+
+                if (!ok) {
+                    sink.done();
+                    return false; // sending failed, go to on_complete()
+                }
+
+                // check if there is more data
+                if (!rd->has_next()) {
+                    if (oaicompat != OAICOMPAT_TYPE_NONE) {
+                        static const std::string ev_done = "data: [DONE]\n\n";
+                        sink.write(ev_done.data(), ev_done.size());
+                    }
+                    sink.done();
+                    return false; // no more data, go to on_complete()
+                }
+
+                // has next data, continue
+                return true;
             };
 
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            auto on_complete = [rd](bool) {
+                rd->stop();
             };
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -5401,8 +5433,7 @@ int main(int argc, char ** argv) {
 
         // create and queue the task
         json responses = json::array();
-        bool error = false;
-        std::unordered_set<int> task_ids;
+        server_response_reader rd(ctx_server);
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -5418,27 +5449,23 @@ int main(int argc, char ** argv) {
 
                 tasks.push_back(std::move(task));
             }
-
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            rd.post_tasks(std::move(tasks));
         }
 
-        // get the result
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
+        // wait for the results
+        auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+        // collect results
+        if (all_results.is_terminated) {
+            return; // connection is closed
+        } else if (all_results.error) {
+            res_error(res, all_results.error->to_json());
+            return;
+        } else {
+            for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
                 responses.push_back(res->to_json());
             }
-        }, [&](const json & error_data) {
-            res_error(res, error_data);
-            error = true;
-        }, req.is_connection_closed);
-
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-
-        if (error) {
-            return;
         }
 
         // write JSON response
@@ -5492,8 +5519,7 @@ int main(int argc, char ** argv) {
 
         // create and queue the task
         json responses = json::array();
-        bool error = false;
-        std::unordered_set<int> task_ids;
+        server_response_reader rd(ctx_server);
         {
             std::vector<server_task> tasks;
             tasks.reserve(documents.size());
@@ -5505,24 +5531,23 @@ int main(int argc, char ** argv) {
                 task.tokens = std::move(tmp);
                 tasks.push_back(std::move(task));
             }
-
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            rd.post_tasks(std::move(tasks));
         }
 
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
+        // wait for the results
+        auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+        // collect results
+        if (all_results.is_terminated) {
+            return; // connection is closed
+        } else if (all_results.error) {
+            res_error(res, all_results.error->to_json());
+            return;
+        } else {
+            for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
                 responses.push_back(res->to_json());
             }
-        }, [&](const json & error_data) {
-            res_error(res, error_data);
-            error = true;
-        }, req.is_connection_closed);
-
-        if (error) {
-            return;
         }
 
         // write JSON response

+ 20 - 6
tools/server/utils.hpp

@@ -453,15 +453,29 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }
 
+// note: if data is a json array, it will be sent as multiple events, one per item
 static bool server_sent_event(httplib::DataSink & sink, const json & data) {
-    const std::string str =
-        "data: " +
-        data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
+    static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool {
+        const std::string str =
+            "data: " +
+            data.dump(-1, ' ', false, json::error_handler_t::replace) +
+            "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
+
+        LOG_DBG("data stream, to_send: %s", str.c_str());
+        return sink.write(str.c_str(), str.size());
+    };
 
-    LOG_DBG("data stream, to_send: %s", str.c_str());
+    if (data.is_array()) {
+        for (const auto & item : data) {
+            if (!send_single(sink, item)) {
+                return false;
+            }
+        }
+    } else {
+        return send_single(sink, data);
+    }
 
-    return sink.write(str.c_str(), str.size());
+    return true;
 }
 
 //