Просмотр исходного кода

server: improve slots scheduling for n_cmpl (#18789)

* server : make sure children tasks are scheduled to launch with parent

* fix

* add comment pointing to this PR

* fix

* clean up

* more debug messages

* add pop_deferred_task with specific ID version

* improve the logic

* simple approach

* no double move

* correct return type of launch_slots_with_parent_task
Xuan-Son Nguyen 2 недель назад
Родитель
Сommit
a04c2b06a3

+ 115 - 74
tools/server/server-context.cpp

@@ -158,7 +158,7 @@ struct server_slot {
     double t_prompt_processing; // ms
     double t_prompt_processing; // ms
     double t_token_generation;  // ms
     double t_token_generation;  // ms
 
 
-    std::function<void(int)> callback_on_release;
+    std::function<void(int /* slot_id */)> callback_on_release;
 
 
     // Speculative decoding stats
     // Speculative decoding stats
     int32_t n_draft_total = 0;      // Total draft tokens generated
     int32_t n_draft_total = 0;      // Total draft tokens generated
@@ -298,17 +298,6 @@ struct server_slot {
         return n_draft_max;
         return n_draft_max;
     }
     }
 
 
-    // note: a slot can also be either a parent or a child
-    // TODO: move to server_task
-    bool is_parent() const {
-        return task->n_children > 0;
-    }
-
-    // TODO: move to server_task
-    bool is_child() const {
-        return task->id_parent >= 0;
-    }
-
     void release() {
     void release() {
         if (is_processing()) {
         if (is_processing()) {
             GGML_ASSERT(task);
             GGML_ASSERT(task);
@@ -321,7 +310,7 @@ struct server_slot {
             state = SLOT_STATE_IDLE;
             state = SLOT_STATE_IDLE;
 
 
             // do not keep context of the child slots - the parent's context is enough
             // do not keep context of the child slots - the parent's context is enough
-            if (is_child()) {
+            if (task->is_child()) {
                 prompt_clear(false);
                 prompt_clear(false);
             }
             }
 
 
@@ -805,8 +794,8 @@ private:
 
 
             SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
             SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
 
-            slot.callback_on_release = [this](int) {
-                queue_tasks.pop_deferred_task();
+            slot.callback_on_release = [this](int slot_id) {
+                queue_tasks.pop_deferred_task(slot_id);
             };
             };
 
 
             slot.reset();
             slot.reset();
@@ -920,9 +909,9 @@ private:
         return true;
         return true;
     }
     }
 
 
-    server_slot * get_slot_by_id(int id) {
+    server_slot * get_slot_by_id(int id_slot) {
         for (server_slot & slot : slots) {
         for (server_slot & slot : slots) {
-            if (slot.id == id) {
+            if (slot.id == id_slot) {
                 return &slot;
                 return &slot;
             }
             }
         }
         }
@@ -1196,12 +1185,11 @@ private:
 
 
         slot.task = std::make_unique<const server_task>(std::move(task));
         slot.task = std::make_unique<const server_task>(std::move(task));
 
 
-        slot.state = slot.is_child()
+        slot.state = slot.task->is_child()
             ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
             ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
             : SLOT_STATE_STARTED;
             : SLOT_STATE_STARTED;
 
 
-        SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());
-
+        SLT_INF(slot, "processing task, is_child = %d\n", slot.task->is_child());
         return true;
         return true;
     }
     }
 
 
@@ -1596,9 +1584,7 @@ private:
 
 
     // tokenize the input if it's set by CLI, return false on error
     // tokenize the input if it's set by CLI, return false on error
     bool tokenize_cli_input(server_task & task) {
     bool tokenize_cli_input(server_task & task) {
-        if (task.cli_input == nullptr) {
-            return true; // nothing to do
-        }
+        GGML_ASSERT(task.cli_input != nullptr);
         try {
         try {
             auto & opt = oai_parser_opt;
             auto & opt = oai_parser_opt;
             common_chat_templates_inputs inputs;
             common_chat_templates_inputs inputs;
@@ -1632,6 +1618,64 @@ private:
         return true;
         return true;
     }
     }
 
 
+    std::vector<server_slot *> get_free_slots(size_t n_slots_needed, int exclude_id_slot) {
+        std::vector<server_slot *> free_slots;
+        for (auto & slot : slots) {
+            if (!slot.is_processing() && slot.id != exclude_id_slot) {
+                free_slots.push_back(&slot);
+            }
+            if (free_slots.size() >= n_slots_needed) {
+                break;
+            }
+        }
+        return free_slots;
+    }
+
+    // launch multiple slots for parent + child tasks
+    bool launch_slots_with_parent_task(server_slot & parent_slot, std::vector<server_slot *> & child_slots, server_task && parent_task) {
+        GGML_ASSERT(!parent_slot.is_processing());
+        GGML_ASSERT(parent_task.is_parent());
+        GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
+
+        int id_parent = parent_task.id;
+
+        SRV_INF("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
+
+        // to be called in case of failure to release all launched slots
+        auto release_slots = [this, id_parent]() {
+            for (auto & slot : slots) {
+                if (slot.is_processing() && (
+                        slot.task->id == id_parent ||
+                        slot.task->id_parent == id_parent
+                )) {
+                    slot.release();
+                }
+            }
+        };
+
+        // launch all child tasks first
+        size_t idx = 0;
+        GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
+        for (auto * slot : child_slots) {
+            int id_child = parent_task.child_tasks[idx].id;
+            if (!launch_slot_with_task(*slot, std::move(parent_task.child_tasks[idx]))) {
+                SRV_ERR("failed to launch slot with child task, id_task = %d\n", id_child);
+                release_slots();
+                return false;
+            }
+            idx++;
+        }
+
+        // finally, launch the parent task
+        if (!launch_slot_with_task(parent_slot, std::move(parent_task))) {
+            SRV_ERR("failed to launch slot with task, id_task = %d\n", id_parent);
+            release_slots();
+            return false;
+        }
+
+        return true;
+    }
+
     void process_single_task(server_task && task) {
     void process_single_task(server_task && task) {
         switch (task.type) {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
             case SERVER_TASK_TYPE_COMPLETION:
@@ -1639,31 +1683,55 @@ private:
             case SERVER_TASK_TYPE_EMBEDDING:
             case SERVER_TASK_TYPE_EMBEDDING:
             case SERVER_TASK_TYPE_RERANK:
             case SERVER_TASK_TYPE_RERANK:
                 {
                 {
-                    if (!tokenize_cli_input(task)) {
-                        break;
+                    // special case: if input is provided via CLI, tokenize it first
+                    // otherwise, no need to tokenize as it's already done inside the HTTP thread
+                    if (task.cli_input != nullptr) {
+                        if (!tokenize_cli_input(task)) {
+                            break;
+                        }
                     }
                     }
 
 
                     const int id_slot = task.id_slot;
                     const int id_slot = task.id_slot;
+                    const int id_task = task.id;
+
+                    server_slot * slot = id_slot != -1
+                                            ? get_slot_by_id(id_slot)
+                                            : get_available_slot(task);
 
 
-                    server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
+                    //
+                    // slot scheduling logic
+                    //
 
 
                     if (slot == nullptr) {
                     if (slot == nullptr) {
                         // if no slot is available, we defer this task for processing later
                         // if no slot is available, we defer this task for processing later
-                        SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
+                        SRV_DBG("no slot is available, defer task, id_task = %d\n", id_task);
                         queue_tasks.defer(std::move(task));
                         queue_tasks.defer(std::move(task));
                         break;
                         break;
                     }
                     }
 
 
                     if (slot->is_processing()) {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         // if requested slot is unavailable, we defer this task for processing later
-                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", id_task);
                         queue_tasks.defer(std::move(task));
                         queue_tasks.defer(std::move(task));
                         break;
                         break;
                     }
                     }
 
 
-                    if (!launch_slot_with_task(*slot, std::move(task))) {
-                        SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
-                        break;
+                    if (task.is_parent()) {
+                        // try getting free slots for all child tasks
+                        size_t n_child_tasks = task.child_tasks.size();
+                        std::vector<server_slot *> child_slots = get_free_slots(n_child_tasks, slot->id);
+                        if (child_slots.size() < n_child_tasks) {
+                            SRV_DBG("not enough free slots for child tasks, n_free = %zu, n_children = %zu, defer task, id_task = %d\n", child_slots.size(), n_child_tasks, id_task);
+                            queue_tasks.defer(std::move(task));
+                            break;
+                        }
+                        if (!launch_slots_with_parent_task(*slot, child_slots, std::move(task))) {
+                            SRV_ERR("failed to launch slot with parent task, id_task = %d\n", id_task);
+                            break; // drop the task
+                        }
+                    } else if (!launch_slot_with_task(*slot, std::move(task))) {
+                        SRV_ERR("failed to launch slot with task, id_task = %d\n", id_task);
+                        break; // drop the task
                     }
                     }
                 } break;
                 } break;
             case SERVER_TASK_TYPE_CANCEL:
             case SERVER_TASK_TYPE_CANCEL:
@@ -1932,7 +2000,7 @@ private:
                     GGML_ABORT("not supported by multimodal");
                     GGML_ABORT("not supported by multimodal");
                 }
                 }
 
 
-                if (slot.is_parent() || slot.is_child()) {
+                if (slot.task->is_parent() || slot.task->is_child()) {
                     send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
                     send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
                     slot.release();
                     slot.release();
                     continue;
                     continue;
@@ -2079,21 +2147,6 @@ private:
 
 
                 // this slot still has a prompt to be processed
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
-                    // wait for all children to be launched
-                    if (slot.is_parent()) {
-                        int n_launched = 0;
-                        for (auto & other : slots) {
-                            if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
-                                ++n_launched;
-                            }
-                        }
-
-                        if (n_launched < slot.task->n_children) {
-                            SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
-                            continue;
-                        }
-                    }
-
                     const auto & input_tokens = slot.task->tokens;
                     const auto & input_tokens = slot.task->tokens;
 
 
                     // TODO: maybe move branch to outside of this loop in the future
                     // TODO: maybe move branch to outside of this loop in the future
@@ -2647,9 +2700,7 @@ private:
 
 
             // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
             // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
             for (auto & slot : slots) {
             for (auto & slot : slots) {
-                if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
-                    SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
-
+                if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
                     std::vector<server_slot *> children;
                     std::vector<server_slot *> children;
                     for (auto & other : slots) {
                     for (auto & other : slots) {
                         if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
                         if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
@@ -2657,17 +2708,15 @@ private:
                         }
                         }
                     }
                     }
 
 
-                    // we can only proceed if all child slots are having the correct tasks
-                    if (slot.task->n_children == (int) children.size()) {
-                        // copy state to the child slots
-                        for (auto & child : children) {
-                            SLT_INF(slot, " - copying state to child %d\n", child->id);
+                    // all children slots should already launched by launch_slots_with_parent_task()
+                    // copy state to the child slots
+                    for (auto & child : children) {
+                        SLT_INF(slot, " - copying state to child %d\n", child->id);
 
 
-                            GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
+                        GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
 
 
-                            slot.copy_state_to(*child);
-                            child->state = SLOT_STATE_DONE_PROMPT;
-                        }
+                        slot.copy_state_to(*child);
+                        child->state = SLOT_STATE_DONE_PROMPT;
                     }
                     }
                 }
                 }
             }
             }
@@ -2943,7 +2992,9 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
             // Everything else, including multimodal completions.
             // Everything else, including multimodal completions.
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         }
-        tasks.reserve(inputs.size());
+
+        // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
+
         for (size_t i = 0; i < inputs.size(); i++) {
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
             server_task task = server_task(type);
 
 
@@ -2964,23 +3015,13 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
 
 
             // prepare child tasks
             // prepare child tasks
             if (task.params.n_cmpl > 1) {
             if (task.params.n_cmpl > 1) {
-                task.n_children = task.params.n_cmpl - 1;
-
-                for (int j = 0; j < task.n_children; j++) {
-                    server_task child = task.create_child(task.id, rd.get_new_id());
-
-                    // use different sampling seed for each child
-                    // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
-                    if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
-                        child.params.sampling.seed += j + 1;
-                    }
-
-                    tasks.push_back(std::move(child));
+                int n_children = task.params.n_cmpl - 1;
+                for (int j = 0; j < n_children; j++) {
+                    task.add_child(task.id, rd.get_new_id());
                 }
                 }
             }
             }
 
 
-            // note: the parent task always launches first
-            tasks.insert(tasks.begin(), std::move(task));
+            tasks.push_back(std::move(task));
         }
         }
 
 
         rd.post_tasks(std::move(tasks));
         rd.post_tasks(std::move(tasks));

+ 34 - 11
tools/server/server-queue.cpp

@@ -74,11 +74,26 @@ int server_queue::get_new_id() {
     return new_id;
     return new_id;
 }
 }
 
 
-void server_queue::pop_deferred_task() {
+void server_queue::pop_deferred_task(int id_slot) {
     std::unique_lock<std::mutex> lock(mutex_tasks);
     std::unique_lock<std::mutex> lock(mutex_tasks);
     if (!queue_tasks_deferred.empty()) {
     if (!queue_tasks_deferred.empty()) {
-        queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
-        queue_tasks_deferred.pop_front();
+        // try to find a task that uses the specified slot
+        bool found = false;
+        for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
+            if (it->id_slot == id_slot) {
+                QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
+                queue_tasks.emplace_front(std::move(*it));
+                queue_tasks_deferred.erase(it);
+                found = true;
+                break;
+            }
+        }
+        // if not tasks found using the slot, just pop the first deferred task (default behavior)
+        if (!found) {
+            QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
+            queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
+            queue_tasks_deferred.pop_front();
+        }
     }
     }
     time_last_task = ggml_time_ms();
     time_last_task = ggml_time_ms();
     condition_tasks.notify_one();
     condition_tasks.notify_one();
@@ -217,12 +232,12 @@ void server_response::add_waiting_task_id(int id_task) {
     waiting_task_ids.insert(id_task);
     waiting_task_ids.insert(id_task);
 }
 }
 
 
-void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
+void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
     std::unique_lock<std::mutex> lock(mutex_results);
     std::unique_lock<std::mutex> lock(mutex_results);
 
 
-    for (const auto & task : tasks) {
-        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
-        waiting_task_ids.insert(task.id);
+    for (const auto & id_task : id_tasks) {
+        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+        waiting_task_ids.insert(id_task);
     }
     }
 }
 }
 
 
@@ -327,6 +342,7 @@ void server_response::terminate() {
 
 
 void server_response_reader::post_task(server_task && task, bool front) {
 void server_response_reader::post_task(server_task && task, bool front) {
     GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
     GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
     task.index = 0;
     task.index = 0;
     id_tasks.insert(task.id);
     id_tasks.insert(task.id);
     states.push_back(task.create_state());
     states.push_back(task.create_state());
@@ -338,11 +354,18 @@ void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool
     GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
     id_tasks = server_task::get_list_id(tasks);
     states.reserve(tasks.size());
     states.reserve(tasks.size());
-    for (size_t i = 0; i < tasks.size(); i++) {
-        tasks[i].index = i;
-        states.push_back(tasks[i].create_state());
+    size_t index = 0;
+    for (auto & task : tasks) {
+        task.index = index++;
+        states.push_back(task.create_state());
+        // for child tasks
+        for (auto & child_task : task.child_tasks) {
+            child_task.index = index++;
+            states.push_back(child_task.create_state());
+        }
     }
     }
-    queue_results.add_waiting_tasks(tasks);
+    GGML_ASSERT(states.size() == id_tasks.size());
+    queue_results.add_waiting_task_ids(id_tasks);
     queue_tasks.post(std::move(tasks), front);
     queue_tasks.post(std::move(tasks), front);
 }
 }
 
 

+ 3 - 2
tools/server/server-queue.h

@@ -44,7 +44,8 @@ public:
     int get_new_id();
     int get_new_id();
 
 
     // Call when the state of one slot is changed, it will move one task from deferred to main queue
     // Call when the state of one slot is changed, it will move one task from deferred to main queue
-    void pop_deferred_task();
+    // prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
+    void pop_deferred_task(int id_slot);
 
 
     // if sleeping, request exiting sleep state and wait until it is done
     // if sleeping, request exiting sleep state and wait until it is done
     // returns immediately if not sleeping
     // returns immediately if not sleeping
@@ -124,7 +125,7 @@ public:
     // add the id_task to the list of tasks waiting for response
     // add the id_task to the list of tasks waiting for response
     void add_waiting_task_id(int id_task);
     void add_waiting_task_id(int id_task);
 
 
-    void add_waiting_tasks(const std::vector<server_task> & tasks);
+    void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
 
 
     // when the request is finished, we can remove task associated with it
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task);
     void remove_waiting_task_id(int id_task);

+ 23 - 3
tools/server/server-task.h

@@ -121,8 +121,10 @@ struct server_task {
     int id_slot   = -1;
     int id_slot   = -1;
 
 
     // used by parallel sampling (multiple completions from same prompt)
     // used by parallel sampling (multiple completions from same prompt)
-    int n_children =  0; // number of tasks reusing this prompt
     int id_parent  = -1;
     int id_parent  = -1;
+    // temporary store of child tasks for scheduling
+    // note: accessing to elements is invalid after the task is moved to server_slot
+    std::vector<server_task> child_tasks;
 
 
     // used by SERVER_TASK_TYPE_INFERENCE
     // used by SERVER_TASK_TYPE_INFERENCE
     task_params   params;
     task_params   params;
@@ -197,11 +199,14 @@ struct server_task {
         std::unordered_set<int> ids(tasks.size());
         std::unordered_set<int> ids(tasks.size());
         for (size_t i = 0; i < tasks.size(); i++) {
         for (size_t i = 0; i < tasks.size(); i++) {
             ids.insert(tasks[i].id);
             ids.insert(tasks[i].id);
+            for (auto & child : tasks[i].child_tasks) {
+                ids.insert(child.id);
+            }
         }
         }
         return ids;
         return ids;
     }
     }
 
 
-    server_task create_child(int id_parent, int id_child) const {
+    void add_child(int id_parent, int id_child) {
         server_task copy;
         server_task copy;
 
 
         copy.id        = id_child;
         copy.id        = id_child;
@@ -209,8 +214,15 @@ struct server_task {
         copy.params    = params;
         copy.params    = params;
         copy.type      = type;
         copy.type      = type;
         copy.tokens    = tokens.clone();
         copy.tokens    = tokens.clone();
+        copy.id_slot   = -1; // child tasks cannot specify slot
+
+        // use different sampling seed for each child
+        // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
+        if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
+            copy.params.sampling.seed += (uint32_t)child_tasks.size() + 1;
+        }
 
 
-        return copy;
+        child_tasks.push_back(std::move(copy));
     }
     }
 
 
     // the task will be moved into queue, then onto slots
     // the task will be moved into queue, then onto slots
@@ -218,6 +230,14 @@ struct server_task {
     task_result_state create_state() const {
     task_result_state create_state() const {
         return task_result_state(params.oaicompat_chat_syntax);
         return task_result_state(params.oaicompat_chat_syntax);
     }
     }
+
+    bool is_parent() const {
+        return child_tasks.size() > 0;
+    }
+
+    bool is_child() const {
+        return id_parent != -1;
+    }
 };
 };
 
 
 struct result_timings {
 struct result_timings {

+ 19 - 13
tools/server/tests/unit/test_chat_completion.py

@@ -491,16 +491,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
 def test_chat_completions_multiple_choices():
 def test_chat_completions_multiple_choices():
     global server
     global server
     server.start()
     server.start()
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": 8,
-        "n": 2,
-        "messages": [
-            {"role": "system", "content": "Book"},
-            {"role": "user", "content": "What is the best book"},
-        ],
-    })
-    assert res.status_code == 200
-    assert len(res.body["choices"]) == 2
-    for choice in res.body["choices"]:
-        assert "assistant" == choice["message"]["role"]
-        assert choice["finish_reason"] == "length"
+    # make sure cache can be reused across multiple choices and multiple requests
+    # ref: https://github.com/ggml-org/llama.cpp/pull/18663
+    for _ in range(2):
+        res = server.make_request("POST", "/chat/completions", data={
+            "max_tokens": 8,
+            "n": 2,
+            "messages": [
+                {"role": "system", "content": "Book"},
+                {"role": "user", "content": "What is the best book"},
+            ],
+            # test forcing the same slot to be used
+            # the scheduler should not be locked up in this case
+            "id_slot": 0,
+        })
+        assert res.status_code == 200
+        assert len(res.body["choices"]) == 2
+        for choice in res.body["choices"]:
+            assert "assistant" == choice["message"]["role"]
+            assert choice["finish_reason"] == "length"