server: delegate result_state creation to server_task (#17835)

* server: delegate result_state creation to server_task

* remove unused states

* add more docs
Xuan-Son Nguyen · 1 month ago
Commit 951520ddb0
6 changed files with 76 additions and 40 deletions
  1. tools/server/README-dev.md (+26 −1)
  2. tools/server/server-context.cpp (+9 −12)
  3. tools/server/server-context.h (+2 −3)
  4. tools/server/server-queue.cpp (+11 −2)
  5. tools/server/server-queue.h (+3 −3)
  6. tools/server/server-task.h (+25 −19)

+ 26 - 1
tools/server/README-dev.md

@@ -42,7 +42,15 @@ graph TD
     server_response --> server_routes
 ```
 
-TODO: mention about how batching is handled by `server_slot`
+### Batching
+
+The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
+
+Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
+
+Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
+
+Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
 
 ### Thread Management
 
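> Editor's note: the batching rules described above can be condensed into a small, self-contained sketch. All names here (`slot_state`, `token_batch`, `update_slots_once`) are illustrative stand-ins, not actual llama.cpp server types; the real loop lives in `update_slots()` and calls `llama_decode`.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-ins for illustration only; not llama.cpp types.
struct slot_state {
    int lora_id;     // slots with different adapters cannot share a batch
    int next_token;  // prompt token or token sampled in the previous step
};

struct token_batch {
    int              lora_id  = -1;  // adapter the current batch was opened with
    size_t           capacity = 512;
    std::vector<int> tokens;

    bool compatible_with(const slot_state & s) const {
        return tokens.empty() || lora_id == s.lora_id;
    }
};

// One simplified iteration of the loop described above: fill the shared
// batch from every compatible slot, then run a single (mock) decode.
void update_slots_once(std::vector<slot_state> & slots, token_batch & batch) {
    for (const auto & slot : slots) {
        if (batch.tokens.size() >= batch.capacity || !batch.compatible_with(slot)) {
            continue; // full or incompatible: this slot waits for a later pass
        }
        batch.lora_id = slot.lora_id;
        batch.tokens.push_back(slot.next_token);
    }
    if (!batch.tokens.empty()) {
        // stand-in for llama_decode(), the main compute step of update_slots()
        std::printf("decoding %zu tokens (lora %d)\n", batch.tokens.size(), batch.lora_id);
        batch.tokens.clear();
    }
}

int main() {
    std::vector<slot_state> slots = {{0, 101}, {0, 102}, {1, 201}};
    token_batch batch;
    update_slots_once(slots, batch); // the two lora-0 slots batch together; the lora-1 slot waits
}
```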
@@ -62,6 +70,23 @@ Each incoming HTTP request is handled by its own thread managed by the HTTP libr
 - All JSON formatting and chat template logic must stay in the HTTP layer.
 - Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
 
+### Example trace of a request
+
+Here is an example trace of an API request for text completion:
+
+- A request arrives at the HTTP layer.
+- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
+- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
+- `server_res_generator` creates a new `task_result_state` for each task:
+    - `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
+    - `server_task` is moved into `server_queue` inside `server_context`.
+- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
+- `update_slots()` processes the task as described in the "Batching" section above.
+- Results may be sent using `send_partial_response` or `send_final_response`, each of which creates a new `server_task_result` and pushes it to the response queue.
+- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
+- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
+- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
+
 ### Testing
 
 `llama-server` includes an automated test suite based on `pytest`.
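
> Editor's note: the trace above boils down to a small consumer loop on the HTTP side. The sketch below is illustrative, not the exact API: `rd.next()` and `send_chunk()` are hypothetical stand-ins, while `get_response_reader()`, `post_tasks()`, `has_next()`, `update()`, and `to_json()` come from this change and the trace above.

```cpp
// Condensed view of the request trace (illustrative; next() and send_chunk()
// are hypothetical stand-ins for the real polling/streaming calls).
server_response_reader rd = ctx_server.get_response_reader();
rd.post_tasks(std::move(tasks));      // per-task states are now created internally

while (rd.has_next()) {
    auto response = rd.next();        // stateless server_task_result from the queue
    response->update();               // apply the task_result_state kept HTTP-side
    send_chunk(response->to_json());  // JSON formatting stays in the HTTP layer
}
```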

+ 9 - 12
tools/server/server-context.cpp

@@ -2589,6 +2589,10 @@ struct server_context_impl {
     int get_slot_n_ctx() {
         return slots.back().n_ctx;
     }
+
+    server_response_reader get_response_reader() {
+        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+    }
 };
 
 //
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
-std::pair<server_queue &, server_response &> server_context::get_queues() {
-    return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+    return impl->get_response_reader();
 }
 
 
@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
 struct server_res_generator : server_http_res {
     server_response_reader rd;
     server_res_generator(server_context_impl & ctx_server)
-        : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
     void ok(const json & response_data) {
         status = 200;
         data = safe_json_to_str(response_data);
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
-        // tracking generation state and partial tool calls
-        std::vector<task_result_state> states;
-
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
-        states.reserve(inputs.size());
         int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type          = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model   = ctx_server.model_name;
-            states.push_back(task.params.oaicompat_chat_syntax);
 
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
@@ -2707,7 +2706,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                         task.id,
                         ctx_server.queue_tasks.get_new_id(),
                         idx++);
-                    states.push_back(child.params.oaicompat_chat_syntax);
                     tasks.push_back(std::move(child));
                 }
             }
@@ -2715,7 +2713,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             tasks.push_back(std::move(task));
         }
 
-        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
 
         // create and queue the task
         json responses = json::array();
-        server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+        server_response_reader rd = ctx_server.get_response_reader();
         {
             std::vector<server_task> tasks;
             tasks.reserve(documents.size());
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
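
> Editor's note: the net effect on the handlers is visible in the removed lines above: the manual `states` vector and the `set_states()` call disappear. Schematically, using only calls shown in this diff:

```cpp
// Before: every task had to be mirrored by a manually built state
//     states.push_back(task.params.oaicompat_chat_syntax);
//     ...
//     rd.set_states(std::move(states));
//     rd.post_tasks(std::move(tasks));
//
// After: posting the tasks is enough; the reader derives each state
// from the task itself via server_task::create_state()
rd.post_tasks(std::move(tasks));
```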

+ 2 - 3
tools/server/server-context.h

@@ -31,9 +31,8 @@ struct server_context {
     // get the underlaying llama_context
     llama_context * get_llama_context() const;
 
-    // get the underlaying queue_tasks and queue_results
-    // used by CLI application
-    std::pair<server_queue &, server_response &> get_queues();
+    // get a new response reader, used by CLI application
+    server_response_reader get_response_reader();
 };
 
 

+ 11 - 2
tools/server/server-queue.cpp

@@ -271,12 +271,21 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::set_states(std::vector<task_result_state> && states) {
-    this->states = std::move(states);
+void server_response_reader::post_task(server_task && task) {
+    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    id_tasks.insert(task.id);
+    states.push_back(task.create_state());
+    queue_results.add_waiting_task_id(task.id);
+    queue_tasks.post(std::move(task));
 }
 
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
+    states.reserve(tasks.size());
+    for (size_t i = 0; i < tasks.size(); i++) {
+        states.push_back(tasks[i].create_state());
+    }
     queue_results.add_waiting_tasks(tasks);
     queue_tasks.post(std::move(tasks));
 }
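
> Editor's note: the new `GGML_ASSERT` calls make each reader single-use. A sketch of the resulting contract (`task_a`/`task_b` are placeholders):

```cpp
// Each reader accepts exactly one post_task()/post_tasks() call.
server_response_reader rd = ctx_server.get_response_reader();
rd.post_task(std::move(task_a));    // ok: id_tasks was empty
// rd.post_task(std::move(task_b)); // would abort:
//                                  // "post_task() can only be called once per reader"
```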

+ 3 - 3
tools/server/server-queue.h

@@ -129,13 +129,13 @@ struct server_response_reader {
     std::vector<task_result_state> states;
 
     // should_stop function will be called each polling_interval_seconds
-    server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
-        : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
+    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
+        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
     ~server_response_reader() {
         stop();
     }
 
-    void set_states(std::vector<task_result_state> && states);
+    void post_task(server_task && task);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;
 

+ 25 - 19
tools/server/server-task.h

@@ -85,6 +85,25 @@ struct task_params {
     json to_json(bool only_metrics = false) const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task {
     int id    = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
@@ -149,6 +168,12 @@ struct server_task {
         copy.tokens    = tokens.clone();
         return copy;
     }
+
+    // the task will be moved into queue, then onto slots
+    // however, the state must be kept by caller (e.g., HTTP thread)
+    task_result_state create_state() const {
+        return task_result_state(params.oaicompat_chat_syntax);
+    }
 };
 
 struct result_timings {
@@ -180,25 +205,6 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
-// struct for tracking the state of a task (e.g., for streaming)
-struct task_result_state {
-    // tracking diffs for partial tool calls
-    std::vector<common_chat_msg_diff> diffs;
-    common_chat_syntax oaicompat_chat_syntax;
-    common_chat_msg chat_msg;
-    std::string generated_text; // append new chunks of generated text here
-    std::vector<std::string> generated_tool_call_ids;
-
-    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
-        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
-
-    // parse partial tool calls and update the internal state
-    common_chat_msg update_chat_msg(
-        const std::string & text_added,
-        bool is_partial,
-        std::vector<common_chat_msg_diff> & diffs);
-};
-
 struct server_task_result {
     int id           = -1;
     int id_slot      = -1;
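
> Editor's note: the comment on `create_state()` is the crux of this change: the task is moved into the queue, but the state it spawned stays with the caller. A sketch of that ownership split (the task type and params setup are assumed for illustration):

```cpp
// Illustrative only: the enum value and surrounding setup are assumed.
server_task task(SERVER_TASK_TYPE_COMPLETION);
// ... fill task.params, including params.oaicompat_chat_syntax ...

task_result_state state = task.create_state(); // value copy of the chat syntax
queue_tasks.post(std::move(task));             // task now lives in the queue / a slot
// `state` remains with the HTTP thread and is updated as results stream back
```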