Просмотр исходного кода

server : handle closed connection for tasks (#18459)

Georgi Gerganov 1 месяц назад
Родитель
Сommit
2a85f720b8
1 измененных файлов с 51 добавлено и 12 удалено
  1. 51 12
      tools/server/server-context.cpp

+ 51 - 12
tools/server/server-context.cpp

@@ -2960,19 +2960,22 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
         // in streaming mode, the first error must be treated as non-stream response
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
         // this is to match the OAI API behavior
         // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
         // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
-        server_task_result_ptr first_result = rd.next(req.should_stop);
+        auto first_result = rd.next(req.should_stop);
         if (first_result == nullptr) {
         if (first_result == nullptr) {
+            GGML_ASSERT(req.should_stop());
             return res; // connection is closed
             return res; // connection is closed
-        } else if (first_result->is_error()) {
+        }
+
+        if (first_result->is_error()) {
             res->error(first_result->to_json());
             res->error(first_result->to_json());
             return res;
             return res;
-        } else {
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
-                || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
-            );
         }
         }
 
 
+        GGML_ASSERT(
+            dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr ||
+            dynamic_cast<server_task_result_cmpl_final*>  (first_result.get()) != nullptr
+        );
+
         // next responses are streamed
         // next responses are streamed
         // to be sent immediately
         // to be sent immediately
         json first_result_json = first_result->to_json();
         json first_result_json = first_result->to_json();
@@ -3028,6 +3031,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
                 auto result = rd.next(req.should_stop);
                 auto result = rd.next(req.should_stop);
                 if (result == nullptr) {
                 if (result == nullptr) {
                     SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
                     SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    GGML_ASSERT(req.should_stop());
                     return false; // should_stop condition met
                     return false; // should_stop condition met
                 }
                 }
 
 
@@ -3111,6 +3115,11 @@ void server_routes::init_routes() {
 
 
         // get the result
         // get the result
         auto result = res->rd.next(req.should_stop);
         auto result = res->rd.next(req.should_stop);
+        if (!result) {
+            // connection was closed
+            GGML_ASSERT(req.should_stop());
+            return res;
+        }
 
 
         if (result->is_error()) {
         if (result->is_error()) {
             res->error(result->to_json());
             res->error(result->to_json());
@@ -3211,6 +3220,11 @@ void server_routes::init_routes() {
 
 
         // get the result
         // get the result
         auto result = res->rd.next(req.should_stop);
         auto result = res->rd.next(req.should_stop);
+        if (!result) {
+            // connection was closed
+            GGML_ASSERT(req.should_stop());
+            return res;
+        }
 
 
         if (result->is_error()) {
         if (result->is_error()) {
             res->error(result->to_json());
             res->error(result->to_json());
@@ -3717,7 +3731,12 @@ void server_routes::init_routes() {
         }
         }
 
 
         // get the result
         // get the result
-        server_task_result_ptr result = rd.next(req.should_stop);
+        auto result = rd.next(req.should_stop);
+        if (!result) {
+            // connection was closed
+            GGML_ASSERT(req.should_stop());
+            return res;
+        }
 
 
         if (result->is_error()) {
         if (result->is_error()) {
             res->error(result->to_json());
             res->error(result->to_json());
@@ -3746,7 +3765,12 @@ void server_routes::init_routes() {
         }
         }
 
 
         // get the result
         // get the result
-        server_task_result_ptr result = rd.next(req.should_stop);
+        auto result = rd.next(req.should_stop);
+        if (!result) {
+            // connection was closed
+            GGML_ASSERT(req.should_stop());
+            return res;
+        }
 
 
         if (result->is_error()) {
         if (result->is_error()) {
             res->error(result->to_json());
             res->error(result->to_json());
@@ -3779,7 +3803,12 @@ std::unique_ptr<server_res_generator> server_routes::handle_slots_save(const ser
         rd.post_task(std::move(task));
         rd.post_task(std::move(task));
     }
     }
 
 
-    server_task_result_ptr result = rd.next(req.should_stop);
+    auto result = rd.next(req.should_stop);
+    if (!result) {
+        // connection was closed
+        GGML_ASSERT(req.should_stop());
+        return res;
+    }
 
 
     if (result->is_error()) {
     if (result->is_error()) {
         res->error(result->to_json());
         res->error(result->to_json());
@@ -3810,7 +3839,12 @@ std::unique_ptr<server_res_generator> server_routes::handle_slots_restore(const
         rd.post_task(std::move(task));
         rd.post_task(std::move(task));
     }
     }
 
 
-    server_task_result_ptr result = rd.next(req.should_stop);
+    auto result = rd.next(req.should_stop);
+    if (!result) {
+        // connection was closed
+        GGML_ASSERT(req.should_stop());
+        return res;
+    }
 
 
     if (result->is_error()) {
     if (result->is_error()) {
         res->error(result->to_json());
         res->error(result->to_json());
@@ -3832,7 +3866,12 @@ std::unique_ptr<server_res_generator> server_routes::handle_slots_erase(const se
         rd.post_task(std::move(task));
         rd.post_task(std::move(task));
     }
     }
 
 
-    server_task_result_ptr result = rd.next(req.should_stop);
+    auto result = rd.next(req.should_stop);
+    if (!result) {
+        // connection was closed
+        GGML_ASSERT(req.should_stop());
+        return res;
+    }
 
 
     if (result->is_error()) {
     if (result->is_error()) {
         res->error(result->to_json());
         res->error(result->to_json());