Przeglądaj źródła

server : implement cancellable request (#11285)

* server : implement cancellable request

* fix typo

* httplib 0.18.5

* fix i underflow
Xuan Son Nguyen 1 rok temu
rodzic
commit
f30f099228

Plik diff jest za duży
+ 406 - 109
examples/server/httplib.h


+ 65 - 8
examples/server/server.cpp

@@ -19,6 +19,7 @@
 #include "loading.html.hpp"
 
 #include <atomic>
+#include <chrono>
 #include <condition_variable>
 #include <cstddef>
 #include <cinttypes>
@@ -32,6 +33,8 @@
 
 using json = nlohmann::ordered_json;
 
+constexpr int HTTP_POLLING_SECONDS = 1;
+
 enum stop_type {
     STOP_TYPE_NONE,
     STOP_TYPE_EOS,
@@ -1602,6 +1605,30 @@ struct server_response {
         // should never reach here
     }
 
+    // same as recv(), but have timeout in seconds
+    // if timeout is reached, nullptr is returned
+    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex_results);
+            bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
+                return !queue_results.empty();
+            });
+            if (!cr_res) {
+                return nullptr;
+            }
+
+            for (int i = 0; i < (int) queue_results.size(); i++) {
+                if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                    server_task_result_ptr res = std::move(queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    return res;
+                }
+            }
+        }
+
+        // should never reach here
+    }
+
     // single-task version of recv()
     server_task_result_ptr recv(int id_task) {
         std::unordered_set<int> id_tasks = {id_task};
@@ -2322,10 +2349,21 @@ struct server_context {
     void receive_multi_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
-            const std::function<void(json)> & error_handler) {
+            const std::function<void(json)> & error_handler,
+            const std::function<bool()> & is_connection_closed) {
         std::vector<server_task_result_ptr> results(id_tasks.size());
-        for (size_t i = 0; i < id_tasks.size(); i++) {
-            server_task_result_ptr result = queue_results.recv(id_tasks);
+        for (int i = 0; i < (int)id_tasks.size(); i++) {
+            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+            if (is_connection_closed()) {
+                cancel_tasks(id_tasks);
+                return;
+            }
+
+            if (result == nullptr) {
+                i--; // retry
+                continue;
+            }
 
             if (result->is_error()) {
                 error_handler(result->to_json());
@@ -2349,10 +2387,20 @@ struct server_context {
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks,
             const std::function<bool(server_task_result_ptr&)> & result_handler,
-            const std::function<void(json)> & error_handler) {
+            const std::function<void(json)> & error_handler,
+            const std::function<bool()> & is_connection_closed) {
         size_t n_finished = 0;
         while (true) {
-            server_task_result_ptr result = queue_results.recv(id_tasks);
+            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+            if (is_connection_closed()) {
+                cancel_tasks(id_tasks);
+                return;
+            }
+
+            if (result == nullptr) {
+                continue; // retry
+            }
 
             if (result->is_error()) {
                 error_handler(result->to_json());
@@ -3633,6 +3681,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
+            std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
@@ -3694,7 +3743,7 @@ int main(int argc, char ** argv) {
                 }
             }, [&](const json & error_data) {
                 res_error(res, error_data);
-            });
+            }, is_connection_closed);
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
@@ -3704,6 +3753,7 @@ int main(int argc, char ** argv) {
                     if (res_json.is_array()) {
                         for (const auto & res : res_json) {
                             if (!server_sent_event(sink, "data", res)) {
+                                // sending failed (HTTP connection closed), cancel the generation
                                 return false;
                             }
                         }
@@ -3713,6 +3763,9 @@ int main(int argc, char ** argv) {
                     }
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
+                }, [&sink]() {
+                    // note: do not use req.is_connection_closed here because req is already destroyed
+                    return !sink.is_writable();
                 });
                 if (oaicompat != OAICOMPAT_TYPE_NONE) {
                     static const std::string ev_done = "data: [DONE]\n\n";
@@ -3735,6 +3788,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_NONE);
     };
@@ -3744,6 +3798,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_COMPLETION);
     };
@@ -3820,6 +3875,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_INFILL,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
@@ -3834,6 +3890,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_CHAT);
     };
@@ -3980,7 +4037,7 @@ int main(int argc, char ** argv) {
             }, [&](const json & error_data) {
                 res_error(res, error_data);
                 error = true;
-            });
+            }, req.is_connection_closed);
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         }
@@ -4070,7 +4127,7 @@ int main(int argc, char ** argv) {
             }, [&](const json & error_data) {
                 res_error(res, error_data);
                 error = true;
-            });
+            }, req.is_connection_closed);
         }
 
         if (error) {

+ 21 - 0
examples/server/tests/unit/test_completion.py

@@ -1,4 +1,5 @@
 import pytest
+import requests
 import time
 from openai import OpenAI
 from utils import *
@@ -405,3 +406,23 @@ def test_n_probs_post_sampling():
             assert "bytes" in prob and type(prob["bytes"]) == list
         # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
         assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+
+
+def test_cancel_request():
+    global server
+    server.n_ctx = 4096
+    server.n_predict = -1
+    server.n_slots = 1
+    server.server_slots = True
+    server.start()
+    # send a request that will take a long time, but cancel it before it finishes
+    try:
+        server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+        }, timeout=0.1)
+    except requests.exceptions.ReadTimeout:
+        pass # expected
+    # make sure the slot is free
+    time.sleep(1) # wait for HTTP_POLLING_SECONDS
+    res = server.make_request("GET", "/slots")
+    assert res.body[0]["is_processing"] == False

+ 4 - 3
examples/server/tests/utils.py

@@ -219,17 +219,18 @@ class ServerProcess:
         path: str,
         data: dict | Any | None = None,
         headers: dict | None = None,
+        timeout: float | None = None,
     ) -> ServerResponse:
         url = f"http://{self.server_host}:{self.server_port}{path}"
         parse_body = False
         if method == "GET":
-            response = requests.get(url, headers=headers)
+            response = requests.get(url, headers=headers, timeout=timeout)
             parse_body = True
         elif method == "POST":
-            response = requests.post(url, headers=headers, json=data)
+            response = requests.post(url, headers=headers, json=data, timeout=timeout)
             parse_body = True
         elif method == "OPTIONS":
-            response = requests.options(url, headers=headers)
+            response = requests.options(url, headers=headers, timeout=timeout)
         else:
             raise ValueError(f"Unimplemented method: {method}")
         result = ServerResponse()

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików