
server: prevent data race from HTTP threads (#18263)

* server: prevent data race from HTTP threads

* fix params

* fix default_generation_settings

* nits: make handle_completions_impl look less strange

* stricter const

* fix GGML_ASSERT(idx < states.size())

* move index to be managed by server_response_reader

* http: make sure req & res lifecycle are tied together

* fix compile

* fix buggy index handling

* fix data race for lora endpoint

* nits: fix shadow variable

* nits: revert redundant changes

* nits: correct naming for json_webui_settings
Xuan-Son Nguyen 1 month ago
Parent
Commit
6ce863c803

+ 1 - 1
tools/cli/cli.cpp

@@ -216,7 +216,7 @@ int main(int argc, char ** argv) {
         ctx_cli.ctx_server.start_loop();
     });
 
-    auto inf = ctx_cli.ctx_server.get_info();
+    auto inf = ctx_cli.ctx_server.get_meta();
     std::string modalities = "text";
     if (inf.has_inp_image) {
         modalities += ", vision";

+ 10 - 17
tools/server/server-common.cpp

@@ -115,26 +115,14 @@ bool lora_should_clear_cache(
         !lora_all_alora(next));
 }
 
-std::vector<common_adapter_lora_info> parse_lora_request(
-        const std::vector<common_adapter_lora_info> & lora_base,
-        const json & data) {
-    std::vector<common_adapter_lora_info> lora(lora_base);
-    int max_idx = lora.size();
-
-    // clear existing value
-    for (auto & entry : lora) {
-        entry.scale = 0.0f;
-    }
+std::map<int, float> parse_lora_request(const json & data) {
+    std::map<int, float> lora;
 
     // set value
     for (const auto & entry : data) {
         int id      = json_value(entry, "id", -1);
         float scale = json_value(entry, "scale", 0.0f);
-        if (0 <= id && id < max_idx) {
-            lora[id].scale = scale;
-        } else {
-            throw std::runtime_error("invalid adapter id");
-        }
+        lora[id] = scale;
     }
 
     return lora;
@@ -1435,7 +1423,7 @@ std::string safe_json_to_str(const json & data) {
 
 // TODO: reuse llama_detokenize
 template <class Iter>
-static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
+static std::string tokens_to_str(const llama_vocab * ctx, Iter begin, Iter end) {
     std::string ret;
     for (; begin != end; ++begin) {
         ret += common_token_to_piece(ctx, *begin);
@@ -1445,7 +1433,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
 }
 
 std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) {
-    return tokens_to_str(ctx, tokens.begin(), tokens.end());
+    auto model = llama_get_model(ctx);
+    return tokens_to_str(llama_model_get_vocab(model), tokens.begin(), tokens.end());
+}
+
+std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens) {
+    return tokens_to_str(vocab, tokens.begin(), tokens.end());
 }
 
 // format incomplete utf-8 multibyte character for output
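After this change, parse_lora_request() no longer reads params_base.lora_adapters on the HTTP thread: it only converts the request JSON into an ID -> scale map, and the range check against the actually loaded adapters is no longer performed at this step. A minimal standalone sketch of the same conversion, assuming nlohmann::json is available as <nlohmann/json.hpp> and using its .value() accessor in place of the server's json_value() helper:

#include <cstdio>
#include <map>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// hypothetical stand-in for the new parse_lora_request()
static std::map<int, float> parse_lora_sketch(const json & data) {
    std::map<int, float> lora;
    for (const auto & entry : data) {
        const int   id    = entry.value("id", -1);
        const float scale = entry.value("scale", 0.0f);
        lora[id] = scale; // the check against the loaded adapter list was removed from this step
    }
    return lora;
}

int main() {
    const json req = json::parse(R"([{"id": 0, "scale": 0.5}, {"id": 2, "scale": 1.0}])");
    for (const auto & [id, scale] : parse_lora_sketch(req)) {
        printf("adapter %d -> scale %.2f\n", id, scale);
    }
    return 0;
}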

+ 2 - 3
tools/server/server-common.h

@@ -107,9 +107,7 @@ bool lora_should_clear_cache(
         const std::vector<common_adapter_lora_info> & current,
         const std::vector<common_adapter_lora_info> & next);
 
-std::vector<common_adapter_lora_info> parse_lora_request(
-        const std::vector<common_adapter_lora_info> & lora_base,
-        const json & data);
+std::map<int, float> parse_lora_request(const json & data);
 
 bool are_lora_equal(
         const std::vector<common_adapter_lora_info> & l1,
@@ -325,6 +323,7 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i
 std::string safe_json_to_str(const json & data);
 
 std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
+std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens);
 
 // format incomplete utf-8 multibyte character for output
 std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);

File diff suppressed because it is too large
+ 288 - 250
tools/server/server-context.cpp


+ 51 - 11
tools/server/server-context.h

@@ -9,11 +9,35 @@
 
 struct server_context_impl; // private implementation
 
-struct server_context_info {
+struct server_context_meta {
     std::string build_info;
     std::string model_name;
+    std::string model_path;
+    bool has_mtmd;
     bool has_inp_image;
     bool has_inp_audio;
+    json json_webui_settings;
+    int slot_n_ctx;
+    enum llama_pooling_type pooling_type;
+
+    // chat template
+    std::string chat_template;
+    std::string chat_template_tool_use;
+
+    // tokens
+    std::string bos_token_str;
+    std::string eos_token_str;
+    llama_token fim_pre_token;
+    llama_token fim_sub_token;
+    llama_token fim_mid_token;
+
+    // model meta
+    enum llama_vocab_type model_vocab_type;
+    int32_t model_vocab_n_tokens;
+    int32_t model_n_ctx_train;
+    int32_t model_n_embd_inp;
+    uint64_t model_n_params;
+    uint64_t model_size;
 };
 
 struct server_context {
@@ -33,14 +57,15 @@ struct server_context {
     void terminate();
 
     // get the underlying llama_context, can return nullptr if sleeping
+    // not thread-safe, should only be used from the main thread
     llama_context * get_llama_context() const;
 
     // get a new response reader, used by CLI application
     server_response_reader get_response_reader();
 
-    // get server info
-    // used by CLI application
-    server_context_info get_info() const;
+    // get server metadata (read-only), can only be called after load_model()
+    // not thread-safe, should only be used from the main thread
+    server_context_meta get_meta() const;
 };
 
 
@@ -48,13 +73,17 @@ struct server_context {
 struct server_res_generator;
 
 struct server_routes {
-    server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
-            : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
-        init_routes();
-    }
+    server_routes(const common_params & params, server_context & ctx_server);
 
     void init_routes();
+
+    // note: this is not thread-safe and can only be called when ctx_http.is_ready is false
+    void update_meta(const server_context & ctx_server) {
+        this->meta = std::make_unique<server_context_meta>(ctx_server.get_meta());
+    }
+
     // handlers using lambda function, so that they can capture `this` without `std::bind`
+    // they won't be called until ctx_http.is_ready is set to true
     server_http_context::handler_t get_health;
     server_http_context::handler_t get_metrics;
     server_http_context::handler_t get_slots;
@@ -78,13 +107,24 @@ struct server_routes {
     server_http_context::handler_t get_lora_adapters;
     server_http_context::handler_t post_lora_adapters;
 private:
-    // TODO: move these outside of server_routes?
+    std::unique_ptr<server_res_generator> handle_completions_impl(
+            const server_http_req & req,
+            server_task_type type,
+            const json & data,
+            const std::vector<raw_buffer> & files,
+            task_response_type res_type);
     std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
     std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
     std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
     std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
 
+    // using unique_ptr to allow late initialization of const
+    std::unique_ptr<const server_context_meta> meta;
+
     const common_params & params;
-    server_context_impl & ctx_server;
-    std::function<bool()> is_ready;
+    const server_context_impl & ctx_server;
+
+    server_queue & queue_tasks;
+    server_response & queue_results;
+    std::unique_ptr<server_res_generator> create_response(bool bypass_sleep = false);
 };
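This header carries the core of the fix: HTTP handlers stop reaching into the live server context and instead read a server_context_meta snapshot that the main thread fills once after load_model(), before ctx_http.is_ready is set. A simplified, self-contained sketch of that pattern (the type and function names below are stand-ins, not the real llama.cpp ones):

#include <atomic>
#include <memory>
#include <string>

struct meta_snapshot {                 // stand-in for server_context_meta
    std::string model_name;
    int         slot_n_ctx = 0;
};

struct routes_sketch {                 // stand-in for server_routes
    // late-initialized once, then read-only: safe to read from many HTTP threads
    std::unique_ptr<const meta_snapshot> meta;

    void update_meta(meta_snapshot m) {        // main thread only, before is_ready = true
        meta = std::make_unique<const meta_snapshot>(std::move(m));
    }

    std::string handle_props() const {         // may run on any HTTP worker thread
        return meta->model_name;               // plain read of an immutable snapshot
    }
};

int main() {
    std::atomic<bool> is_ready{false};
    routes_sketch routes;
    routes.update_meta({"model.gguf", 4096});  // after the model is loaded
    is_ready.store(true);                      // handlers are only reachable once this is true
    return routes.handle_props() == "model.gguf" ? 0 : 1;
}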

+ 16 - 10
tools/server/server-http.cpp

@@ -177,12 +177,11 @@ bool server_http_context::init(const common_params & params) {
         if (!ready) {
             auto tmp = string_split<std::string>(req.path, '.');
             if (req.path == "/" || tmp.back() == "html") {
-                res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
                 res.status = 503;
-            } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
-                // allow the models endpoint to be accessed during loading
-                return true;
+                res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
             } else {
+                // no endpoints are allowed to be accessed when the server is not ready
+                // this is to prevent any data races or inconsistent states
                 res.status = 503;
                 res.set_content(
                     safe_json_to_str(json {
@@ -334,12 +333,16 @@ static std::map<std::string, std::string> get_headers(const httplib::Request & r
     return headers;
 }
 
-static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
+// using unique_ptr for request to allow safe capturing in lambdas
+using server_http_req_ptr = std::unique_ptr<server_http_req>;
+
+static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
     if (response->is_stream()) {
         res.status = response->status;
         set_headers(res, response->headers);
         std::string content_type = response->content_type;
         // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
+        std::shared_ptr<server_http_req> q_ptr = std::move(request);
         std::shared_ptr<server_http_res> r_ptr = std::move(response);
         const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
             std::string chunk;
@@ -355,8 +358,9 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Re
             }
             return has_next;
         };
-        const auto on_complete = [response = r_ptr](bool) mutable {
+        const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
             response.reset(); // trigger the destruction of the response object
+            request.reset();  // trigger the destruction of the request object
         };
         res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
     } else {
@@ -368,27 +372,29 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Re
 
 void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
     pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
-        server_http_res_ptr response = handler(server_http_req{
+        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
             get_params(req),
             get_headers(req),
             req.path,
             req.body,
             req.is_connection_closed
         });
-        process_handler_response(response, res);
+        server_http_res_ptr response = handler(*request);
+        process_handler_response(std::move(request), response, res);
     });
 }
 
 void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
     pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
-        server_http_res_ptr response = handler(server_http_req{
+        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
             get_params(req),
             get_headers(req),
             req.path,
             req.body,
             req.is_connection_closed
         });
-        process_handler_response(response, res);
+        server_http_res_ptr response = handler(*request);
+        process_handler_response(std::move(request), response, res);
     });
 }
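Tying the request lifetime to the response matters for streaming responses: the chunked content provider and on_complete callbacks can outlive the handler invocation, so both objects are moved into shared_ptrs that the callbacks capture. A self-contained sketch of the lifetime-tying idea, with httplib's machinery replaced by plain std::function callbacks:

#include <cstdio>
#include <functional>
#include <memory>

struct fake_req { ~fake_req() { puts("request destroyed"); } };
struct fake_res { ~fake_res() { puts("response destroyed"); } };

int main() {
    auto q_ptr = std::make_shared<fake_req>();  // mirrors converting the unique_ptrs to shared_ptrs
    auto r_ptr = std::make_shared<fake_res>();

    // the streaming provider only needs the response...
    std::function<bool()> provider = [response = r_ptr]() {
        return false;                           // pretend the last chunk was just sent
    };
    // ...while on_complete keeps both objects alive until streaming really finishes
    std::function<void()> on_complete = [request = q_ptr, response = r_ptr]() mutable {
        response.reset();
        request.reset();
    };

    q_ptr.reset();        // the handler-scope pointers go away...
    r_ptr.reset();
    provider();           // ...but the callbacks still own both objects
    provider = nullptr;   // provider releases its reference to the response
    on_complete();        // only now do the two destructors run
    return 0;
}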
 

+ 9 - 6
tools/server/server-queue.cpp

@@ -325,23 +325,25 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::post_task(server_task && task) {
+void server_response_reader::post_task(server_task && task, bool front) {
     GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    task.index = 0;
     id_tasks.insert(task.id);
     states.push_back(task.create_state());
     queue_results.add_waiting_task_id(task.id);
-    queue_tasks.post(std::move(task));
+    queue_tasks.post(std::move(task), front);
 }
 
-void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
     GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
     states.reserve(tasks.size());
     for (size_t i = 0; i < tasks.size(); i++) {
+        tasks[i].index = i;
         states.push_back(tasks[i].create_state());
     }
     queue_results.add_waiting_tasks(tasks);
-    queue_tasks.post(std::move(tasks));
+    queue_tasks.post(std::move(tasks), front);
 }
 
 bool server_response_reader::has_next() const {
@@ -367,7 +369,7 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
             }
             if (!states.empty()) {
                 // update the generation state if needed
-                size_t idx = result->get_index();
+                const size_t idx = result->index;
                 GGML_ASSERT(idx < states.size());
                 result->update(states[idx]);
             }
@@ -383,6 +385,7 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
 
 server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
     batch_response batch_res;
+    batch_res.results.clear();
     batch_res.results.resize(id_tasks.size());
     while (has_next()) {
         auto res = next(should_stop);
@@ -394,7 +397,7 @@ server_response_reader::batch_response server_response_reader::wait_for_all(cons
             batch_res.error = std::move(res);
             return batch_res;
         }
-        const size_t idx = res->get_index();
+        const size_t idx = res->index;
         GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
         GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
         batch_res.results[idx] = std::move(res);
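Since the reader now stamps the index itself (0 for a single task, 0..N-1 for a batch), out-of-order results can be slotted back into request order without trusting an index assigned elsewhere, which is what the GGML_ASSERT checks guard. A simplified sketch of the reordering done in wait_for_all(), using stand-in types rather than the real result classes:

#include <cassert>
#include <cstdio>
#include <memory>
#include <vector>

struct result { size_t index; int value; };

int main() {
    // three "tasks" posted as one batch -> indices 0, 1, 2
    std::vector<std::unique_ptr<result>> ordered(3);

    // results arrive out of order (e.g. another slot finished first)
    std::vector<result> arrived = { {2, 30}, {0, 10}, {1, 20} };
    for (auto & r : arrived) {
        assert(r.index < ordered.size());        // mirrors GGML_ASSERT(idx < ...)
        assert(ordered[r.index] == nullptr);     // duplicate-result check
        ordered[r.index] = std::make_unique<result>(r);
    }
    for (const auto & r : ordered) {
        printf("index %zu -> %d\n", r->index, r->value);
    }
    return 0;
}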

+ 5 - 2
tools/server/server-queue.h

@@ -5,6 +5,7 @@
 #include <condition_variable>
 #include <deque>
 #include <mutex>
+#include <vector>
 #include <unordered_set>
 
 // struct for managing server tasks
@@ -173,8 +174,10 @@ struct server_response_reader {
     int get_new_id() {
         return queue_tasks.get_new_id();
     }
-    void post_task(server_task && task);
-    void post_tasks(std::vector<server_task> && tasks);
+
+    // if front = true, the task will be posted to the front of the queue (high priority)
+    void post_task(server_task && task, bool front = false);
+    void post_tasks(std::vector<server_task> && tasks, bool front = false);
     bool has_next() const;
 
     // return nullptr if should_stop() is true before receiving a result
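The new front flag is a small priority mechanism: assuming the underlying task queue is the std::deque this header already includes, posting to the front lets a cheap task be picked up before work that is already queued. A minimal sketch of that behaviour, with the real server_queue's ID bookkeeping and condition-variable wakeup omitted:

#include <cstdio>
#include <deque>

struct task { int id; };

struct queue_sketch {
    std::deque<task> tasks;
    void post(task && t, bool front = false) {
        if (front) {
            tasks.push_front(std::move(t)); // high priority: picked up next
        } else {
            tasks.push_back(std::move(t));
        }
    }
};

int main() {
    queue_sketch q;
    q.post({1});
    q.post({2});
    q.post({3}, /* front = */ true);
    for (const auto & t : q.tasks) {
        printf("%d ", t.id);                // prints: 3 1 2
    }
    printf("\n");
    return 0;
}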

+ 32 - 10
tools/server/server-task.cpp

@@ -32,8 +32,8 @@ json task_params::to_json(bool only_metrics) const {
     }
 
     json lora = json::array();
-    for (size_t i = 0; i < this->lora.size(); ++i) {
-        lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+    for (auto & it : this->lora) {
+        lora.push_back({{"id", it.first}, {"scale", it.second}});
     }
 
     if (only_metrics) {
@@ -145,12 +145,10 @@ json task_params::to_json(bool only_metrics) const {
 //
 
 task_params server_task::params_from_json_cmpl(
-        const llama_context * ctx,
+        const llama_vocab * vocab,
         const common_params & params_base,
+        const int n_ctx_slot,
         const json & data) {
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
     task_params params;
 
     // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
@@ -223,12 +221,12 @@ task_params server_task::params_from_json_cmpl(
 
     if (data.contains("lora")) {
         if (data.at("lora").is_array()) {
-            params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
+            params.lora = parse_lora_request(data.at("lora"));
         } else {
             throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
         }
     } else {
-        params.lora = params_base.lora_adapters;
+        params.lora = {};
     }
 
     // TODO: add more sanity checks for the input parameters
@@ -243,11 +241,11 @@ task_params server_task::params_from_json_cmpl(
 
     if (params.sampling.penalty_last_n == -1) {
         // note: should be the slot's context and not the full context, but it's ok
-        params.sampling.penalty_last_n = llama_n_ctx(ctx);
+        params.sampling.penalty_last_n = n_ctx_slot;
     }
 
     if (params.sampling.dry_penalty_last_n == -1) {
-        params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+        params.sampling.dry_penalty_last_n = n_ctx_slot;
     }
 
     if (params.sampling.dry_base < 1.0f) {
@@ -1324,6 +1322,30 @@ json server_task_result_slot_erase::to_json() {
     };
 }
 
+//
+// server_task_result_get_lora
+//
+
+json server_task_result_get_lora::to_json() {
+    json result = json::array();
+    for (size_t i = 0; i < loras.size(); ++i) {
+        auto & lora = loras[i];
+        json entry = {
+            {"id",            i},
+            {"path",          lora.info.path},
+            {"scale",         lora.info.scale},
+            {"task_name",     lora.info.task_name},
+            {"prompt_prefix", lora.info.prompt_prefix},
+        };
+        if (!lora.alora_invocation_tokens.empty()) {
+            entry["alora_invocation_string"] = lora.alora_invocation_string;
+            entry["alora_invocation_tokens"] = lora.alora_invocation_tokens;
+        }
+        result.push_back(std::move(entry));
+    }
+    return result;
+}
+
 //
 // server_task_result_apply_lora
 //
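The new server_task_result_get_lora lets the server thread collect everything /lora-adapters needs, so the HTTP handler only serializes a finished result. For reference, a standalone sketch that prints the same JSON shape as the to_json() above; the values are made up, and the alora_* fields would only appear when invocation tokens exist:

#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json result = json::array();
    result.push_back({
        {"id",            0},                       // index in the loaded adapter list
        {"path",          "adapters/example.gguf"}, // made-up path
        {"scale",         0.5},
        {"task_name",     ""},
        {"prompt_prefix", ""}
    });
    printf("%s\n", result.dump(2).c_str());
    return 0;
}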

+ 28 - 35
tools/server/server-task.h

@@ -6,6 +6,7 @@
 #include <string>
 #include <unordered_set>
 #include <list>
+#include <map>
 
 // TODO: prevent including the whole server-common.h as we only use server_tokens
 #include "server-common.h"
@@ -23,6 +24,7 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_GET_LORA,
     SERVER_TASK_TYPE_SET_LORA,
 };
 
@@ -60,7 +62,7 @@ struct task_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
-    std::vector<common_adapter_lora_info> lora;
+    std::map<int, float> lora; // mapping adapter ID -> scale
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
@@ -105,8 +107,10 @@ struct task_result_state {
 };
 
 struct server_task {
-    int id    = -1; // to be filled by server_queue
-    int index = -1; // used when there are multiple prompts (batch request)
+    int id = -1; // to be filled by server_queue
+
+    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
+    size_t index = 0; // used when there are multiple prompts (batch request)
 
     // used by SERVER_TASK_TYPE_CANCEL
     int id_target = -1;
@@ -138,7 +142,7 @@ struct server_task {
     bool metrics_reset_bucket = false;
 
     // used by SERVER_TASK_TYPE_SET_LORA
-    std::vector<common_adapter_lora_info> set_lora;
+    std::map<int, float> set_lora; // mapping adapter ID -> scale
 
     server_task() = default;
 
@@ -149,9 +153,10 @@ struct server_task {
     }
 
     static task_params params_from_json_cmpl(
-            const llama_context * ctx,
-            const common_params & params_base,
-            const json & data);
+        const llama_vocab * vocab,
+        const common_params & params_base,
+        const int n_ctx_slot,
+        const json & data);
 
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -162,10 +167,9 @@ struct server_task {
         return ids;
     }
 
-    server_task create_child(int id_parent, int id_child, int idx) const {
+    server_task create_child(int id_parent, int id_child) const {
         server_task copy;
         copy.id        = id_child;
-        copy.index     = idx;
         copy.id_parent = id_parent;
         copy.params    = params;
         copy.type      = type;
@@ -212,6 +216,10 @@ struct result_prompt_progress {
 struct server_task_result {
     int id           = -1;
     int id_slot      = -1;
+
+    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
+    size_t index = 0; // to be used for batched tasks
+
     virtual bool is_error() {
         // only used by server_task_result_error
         return false;
@@ -220,9 +228,6 @@ struct server_task_result {
         // only used by server_task_result_cmpl_*
         return true;
     }
-    virtual int get_index() {
-        return -1;
-    }
     virtual void update(task_result_state &) {
         // only used by server_task_result_cmpl_*
     }
@@ -255,8 +260,6 @@ struct completion_token_output {
 };
 
 struct server_task_result_cmpl_final : server_task_result {
-    int index = 0;
-
     std::string content;
     llama_tokens tokens;
 
@@ -289,10 +292,6 @@ struct server_task_result_cmpl_final : server_task_result {
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
     bool is_updated = false;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual bool is_stop() override {
         return true; // in stream mode, final responses are considered stop
     }
@@ -318,8 +317,6 @@ struct server_task_result_cmpl_final : server_task_result {
 };
 
 struct server_task_result_cmpl_partial : server_task_result {
-    int index = 0;
-
     std::string  content;
     llama_tokens tokens;
 
@@ -340,10 +337,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
     bool is_updated = false;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual bool is_stop() override {
         return false; // in stream mode, partial responses are not considered stop
     }
@@ -365,7 +358,6 @@ struct server_task_result_cmpl_partial : server_task_result {
 };
 
 struct server_task_result_embd : server_task_result {
-    int index = 0;
     std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
@@ -373,10 +365,6 @@ struct server_task_result_embd : server_task_result {
     // response formatting
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual json to_json() override;
 
     json to_json_non_oaicompat();
@@ -385,20 +373,14 @@ struct server_task_result_embd : server_task_result {
 };
 
 struct server_task_result_rerank : server_task_result {
-    int index = 0;
     float score = -1e6;
 
     int32_t n_tokens;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual json to_json() override;
 };
 
 struct server_task_result_error : server_task_result {
-    int index = 0;
     error_type err_type = ERROR_TYPE_SERVER;
     std::string err_msg;
 
@@ -460,6 +442,17 @@ struct server_task_result_slot_erase : server_task_result {
     virtual json to_json() override;
 };
 
+struct server_task_result_get_lora : server_task_result {
+    struct lora {
+        common_adapter_lora_info info;
+        std::string  alora_invocation_string;
+        llama_tokens alora_invocation_tokens;
+    };
+    std::vector<lora> loras;
+
+    virtual json to_json() override;
+};
+
 struct server_task_result_apply_lora : server_task_result {
     virtual json to_json() override;
 };
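Moving index into server_task_result replaces one get_index() override per result type with a single base-class field. A tiny sketch of that refactor, using stand-in names:

#include <cstdio>
#include <memory>

struct result_base {
    size_t index = 0;            // shared by all result types now
    virtual ~result_base() = default;
};

struct result_final   : result_base {};
struct result_partial : result_base {};

int main() {
    std::unique_ptr<result_base> r = std::make_unique<result_partial>();
    r->index = 3;                // no virtual call needed to read or write it
    printf("%zu\n", r->index);
    return 0;
}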

+ 2 - 1
tools/server/server.cpp

@@ -119,7 +119,7 @@ int main(int argc, char ** argv, char ** envp) {
     //
 
     // register API routes
-    server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); });
+    server_routes routes(params, ctx_server);
 
     bool is_router_server = params.model.path.empty();
     std::optional<server_models_routes> models_routes{};
@@ -252,6 +252,7 @@ int main(int argc, char ** argv, char ** envp) {
             return 1;
         }
 
+        routes.update_meta(ctx_server);
         ctx_http.is_ready.store(true);
 
         LOG_INF("%s: model loaded\n", __func__);

Some files were not shown because too many files changed in this diff