Sfoglia il codice sorgente

server: split server.cpp code into server/common/task/queue (#17362)

* add server-task, server-common

* add server-queue

* rm redundant includes

* move enum stop_type to server-task

* server : headers cleanup

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen 1 mese fa
parent
commit
b8372eecd9

+ 6 - 1
tools/server/CMakeLists.txt

@@ -13,9 +13,14 @@ endif()
 
 set(TARGET_SRCS
     server.cpp
-    utils.hpp
     server-http.cpp
     server-http.h
+    server-task.cpp
+    server-task.h
+    server-queue.cpp
+    server-queue.h
+    server-common.cpp
+    server-common.h
 )
 set(PUBLIC_ASSETS
     index.html.gz

File diff suppressed because it is too large
+ 627 - 377
tools/server/server-common.cpp


+ 349 - 0
tools/server/server-common.h

@@ -0,0 +1,349 @@
+#pragma once
+
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+#include "chat.h"
+#include "mtmd.h"
+
+#define JSON_ASSERT GGML_ASSERT
+#include <nlohmann/json.hpp>
+
+#include <string>
+#include <vector>
+#include <cinttypes>
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
+
+const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
+
+using json = nlohmann::ordered_json;
+
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+using raw_buffer = std::vector<uint8_t>;
+
+template <typename T>
+static T json_value(const json & body, const std::string & key, const T & default_value) {
+    // Fallback null to default value
+    if (body.contains(key) && !body.at(key).is_null()) {
+        try {
+            return body.at(key);
+        } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
+            LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
+            return default_value;
+        }
+    } else {
+        return default_value;
+    }
+}
+
+// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
+enum error_type {
+    ERROR_TYPE_INVALID_REQUEST,
+    ERROR_TYPE_AUTHENTICATION,
+    ERROR_TYPE_SERVER,
+    ERROR_TYPE_NOT_FOUND,
+    ERROR_TYPE_PERMISSION,
+    ERROR_TYPE_UNAVAILABLE, // custom error
+    ERROR_TYPE_NOT_SUPPORTED, // custom error
+    ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
+};
+
+// thin wrapper around common_grammar_trigger with (de)serialization functions
+struct server_grammar_trigger {
+    common_grammar_trigger value;
+
+    server_grammar_trigger() = default;
+    server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
+    server_grammar_trigger(const json & in) {
+        value.type = (common_grammar_trigger_type) in.at("type").get<int>();
+        value.value = in.at("value").get<std::string>();
+        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+            value.token = (llama_token) in.at("token").get<int>();
+        }
+    }
+
+    json to_json() const {
+        json out {
+            {"type", (int) value.type},
+            {"value", value.value},
+        };
+        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+            out["token"] = (int) value.token;
+        }
+        return out;
+    }
+};
+
+json format_error_response(const std::string & message, const enum error_type type);
+
+//
+// random string / id
+//
+
+std::string random_string();
+std::string gen_chatcmplid();
+std::string gen_tool_call_id();
+
+//
+// lora utils
+//
+
+// check whether the given lora set has only aloras activated (empty => false)
+bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);
+
+// if the two sets of loras are different, they require a cache clear unless the
+// change is only from aloras to aloras.
+bool lora_should_clear_cache(
+        const std::vector<common_adapter_lora_info> & current,
+        const std::vector<common_adapter_lora_info> & next);
+
+std::vector<common_adapter_lora_info> parse_lora_request(
+        const std::vector<common_adapter_lora_info> & lora_base,
+        const json & data);
+
+bool are_lora_equal(
+        const std::vector<common_adapter_lora_info> & l1,
+        const std::vector<common_adapter_lora_info> & l2);
+
+// get the ids of all enabled loras
+std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
+
+//
+// server_tokens
+//
+
+/**
+ * server_tokens is a helper to manage the input tokens and image for the server.
+ * it is made this way to simplify the logic of KV cache management.
+ */
+struct server_tokens {
+    bool has_mtmd = false;
+
+private: // disallow accessing these members directly, risking out-of-sync
+
+    // map a **start** index in tokens to the image chunk
+    // note: the order need to be in-sync with tokens
+    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
+
+    // list of tokens
+    //   if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
+    //   otherwise, it is a normal text token
+    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
+    // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
+    llama_tokens tokens;
+
+    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+    //      [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+    // idx  0   1   2   3   4   5      6      7      8      9      10
+    // pos  0   1   2   3   4   5      5      5      7      7      7
+    // map_idx_to_media will contain: {5, img0}, {8, img1}
+
+public:
+    server_tokens() = default;
+    ~server_tokens() = default;
+
+    // Prevent copying
+    // TODO: server_tokens should be copyable - remove this:
+    server_tokens(const server_tokens&) = delete;
+    server_tokens& operator=(const server_tokens&) = delete;
+
+    // Allow moving (usually implicitly generated if members are movable)
+    server_tokens(server_tokens&&) = default;
+    server_tokens& operator=(server_tokens&&) = default;
+
+    // Allow accessing elements using [] operator
+    llama_token operator[](size_t index) { return tokens[index]; }
+    const llama_token& operator[](size_t index) const { return tokens[index]; }
+
+    server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
+    server_tokens(const llama_tokens & tokens, bool has_mtmd);
+
+    // for debugging
+    std::string str() const;
+
+    llama_pos pos_next() const;
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
+
+    void push_back(llama_token tok);
+
+    // will create a copy of the chunk if it contains non-text data
+    void push_back(const mtmd_input_chunk * chunk);
+
+    // appends server tokens, updates the media map. copies media chunks.
+    void push_back(server_tokens & tokens);
+
+    // for compatibility with context shift and prompt truncation
+    void insert(const llama_tokens & inp_tokens);
+
+    // for compatibility with speculative decoding, ctx shift, slot save/load
+    const llama_tokens & get_text_tokens() const;
+
+    // for compatibility with speculative decoding
+    void set_token(llama_pos pos, llama_token id);
+
+    size_t size() const { return tokens.size(); }
+
+    bool empty() const { return tokens.empty(); }
+
+    void clear() {
+        map_idx_to_media.clear();
+        tokens.clear();
+    }
+
+    void keep_first(size_t n);
+
+    std::string detokenize(const llama_context * ctx, bool special) const;
+
+    size_t get_common_prefix(const server_tokens & b) const;
+
+    // make sure all text tokens are within the vocab range
+    bool validate(const struct llama_context * ctx) const;
+
+    // encode and decode the image chunk
+    int32_t process_chunk(
+                llama_context * ctx,
+                mtmd_context * mctx,
+                size_t idx,
+                llama_pos pos,
+                int32_t seq_id,
+                size_t & n_tokens_out) const;
+};
+
+
+//
+// tokenizer and input processing utils
+//
+
+bool json_is_array_of_numbers(const json & data);
+
+// is array having BOTH numbers & strings?
+bool json_is_array_of_mixed_numbers_strings(const json & data);
+
+// does array have any individual integers/tokens?
+bool json_is_array_and_contains_numbers(const json & data);
+
+// get value by path(key1 / key2)
+json json_get_nested_values(const std::vector<std::string> & paths, const json & js);
+
+/**
+ * this handles 2 cases:
+ * - only string, example: "string"
+ * - mixed string and tokens, example: [12, 34, "string", 56, 78]
+ */
+llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);
+
+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+size_t validate_utf8(const std::string& text);
+
+// process mtmd prompt, return the server_tokens containing both text tokens and media chunks
+server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);
+
+/**
+ * break the input "prompt" object into multiple prompt if needed, then tokenize them
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
+ * and multiple prompts (multi-tasks):
+ * - "prompt": ["string1", "string2"]
+ * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
+ * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
+ */
+std::vector<server_tokens> tokenize_input_prompts(
+                                        const llama_vocab * vocab,
+                                        mtmd_context * mctx,
+                                        const json & json_prompt,
+                                        bool add_special,
+                                        bool parse_special);
+
+//
+// OAI utils
+//
+
+// used by /completions endpoint
+json oaicompat_completion_params_parse(const json & body);
+
+struct oaicompat_parser_options {
+    bool use_jinja;
+    bool prefill_assistant;
+    common_reasoning_format reasoning_format;
+    std::map<std::string,std::string> chat_template_kwargs;
+    common_chat_templates * tmpls;
+    bool allow_image;
+    bool allow_audio;
+    bool enable_thinking = true;
+};
+
+// used by /chat/completions endpoint
+json oaicompat_chat_params_parse(
+    json & body, /* openai api json semantics */
+    const oaicompat_parser_options & opt,
+    std::vector<raw_buffer> & out_files);
+
+// TODO: move it to server-task.cpp
+json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
+
+// TODO: move it to server-task.cpp
+json format_response_rerank(
+        const json & request,
+        const json & ranks,
+        bool is_tei_format,
+        std::vector<std::string> & texts,
+        int top_n);
+
+//
+// other utils
+//
+
+std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
+
+std::string safe_json_to_str(const json & data);
+
+std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
+
+// format incomplete utf-8 multibyte character for output
+std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
+
+// format server-sent event (SSE), return the formatted string to send
+// note: if data is a json array, it will be sent as multiple events, one per item
+std::string format_sse(const json & data);
+
+bool is_valid_utf8(const std::string & str);
+
+//
+// formatting output responses
+// TODO: move these to server-task.cpp
+//
+
+llama_tokens format_prompt_infill(
+        const llama_vocab * vocab,
+        const json & input_prefix,
+        const json & input_suffix,
+        const json & input_extra,
+        const int n_batch,
+        const int n_predict,
+        const int n_ctx,
+        const bool spm_infill,
+        const llama_tokens & tokens_prompt);
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
+server_tokens format_prompt_rerank(
+        const struct llama_model * model,
+        const struct llama_vocab * vocab,
+        mtmd_context * mctx,
+        const std::string & query,
+        const std::string & doc);

+ 1 - 1
tools/server/server-http.cpp

@@ -1,6 +1,6 @@
-#include "utils.hpp"
 #include "common.h"
 #include "server-http.h"
+#include "server-common.h"
 
 #include <cpp-httplib/httplib.h>
 

+ 268 - 0
tools/server/server-queue.cpp

@@ -0,0 +1,268 @@
+#include "server-task.h"
+#include "server-queue.h"
+
+#include "log.h"
+
+#include <chrono>
+
+#define QUE_INF(fmt, ...) LOG_INF("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define RES_INF(fmt, ...) LOG_INF("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_WRN(fmt, ...) LOG_WRN("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_ERR(fmt, ...) LOG_ERR("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_DBG(fmt, ...) LOG_DBG("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+//
+// server_queue
+//
+
+int server_queue::post(server_task && task, bool front) {
+    std::unique_lock<std::mutex> lock(mutex_tasks);
+    GGML_ASSERT(task.id != -1);
+    // if this is cancel task make sure to clean up pending tasks
+    if (task.type == SERVER_TASK_TYPE_CANCEL) {
+        cleanup_pending_task(task.id_target);
+    }
+    const int task_id = task.id;
+    QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
+    if (front) {
+        queue_tasks.push_front(std::move(task));
+    } else {
+        queue_tasks.push_back(std::move(task));
+    }
+    condition_tasks.notify_one();
+    return task_id;
+}
+
+int server_queue::post(std::vector<server_task> && tasks, bool front) {
+    std::unique_lock<std::mutex> lock(mutex_tasks);
+    for (auto & task : tasks) {
+        if (task.id == -1) {
+            task.id = id++;
+        }
+        // if this is cancel task make sure to clean up pending tasks
+        if (task.type == SERVER_TASK_TYPE_CANCEL) {
+            cleanup_pending_task(task.id_target);
+        }
+        QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
+        if (front) {
+            queue_tasks.push_front(std::move(task));
+        } else {
+            queue_tasks.push_back(std::move(task));
+        }
+    }
+    condition_tasks.notify_one();
+    return 0;
+}
+
+void server_queue::defer(server_task && task) {
+    std::unique_lock<std::mutex> lock(mutex_tasks);
+    QUE_DBG("defer task, id = %d\n", task.id);
+    queue_tasks_deferred.push_back(std::move(task));
+    condition_tasks.notify_one();
+}
+
+int server_queue::get_new_id() {
+    std::unique_lock<std::mutex> lock(mutex_tasks);
+    int new_id = id++;
+    return new_id;
+}
+
+void server_queue::on_new_task(std::function<void(server_task &&)> callback) {
+    callback_new_task = std::move(callback);
+}
+
+void server_queue::on_update_slots(std::function<void(void)> callback) {
+    callback_update_slots = std::move(callback);
+}
+
+void server_queue::pop_deferred_task() {
+    std::unique_lock<std::mutex> lock(mutex_tasks);
+    if (!queue_tasks_deferred.empty()) {
+        queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
+        queue_tasks_deferred.pop_front();
+    }
+    condition_tasks.notify_one();
+}
+
+void server_queue::terminate() {
+    std::unique_lock<std::mutex> lock(mutex_tasks);
+    running = false;
+    condition_tasks.notify_all();
+}
+
+void server_queue::start_loop() {
+    running = true;
+
+    while (true) {
+        QUE_DBG("%s", "processing new tasks\n");
+
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex_tasks);
+            if (!running) {
+                QUE_DBG("%s", "terminate\n");
+                return;
+            }
+            if (queue_tasks.empty()) {
+                lock.unlock();
+                break;
+            }
+            server_task task = std::move(queue_tasks.front());
+            queue_tasks.pop_front();
+            lock.unlock();
+
+            QUE_DBG("processing task, id = %d\n", task.id);
+            callback_new_task(std::move(task));
+        }
+
+        // all tasks in the current loop is processed, slots data is now ready
+        QUE_DBG("%s", "update slots\n");
+
+        callback_update_slots();
+
+        QUE_DBG("%s", "waiting for new tasks\n");
+        {
+            std::unique_lock<std::mutex> lock(mutex_tasks);
+            if (!running) {
+                QUE_DBG("%s", "terminate\n");
+                return;
+            }
+            if (queue_tasks.empty()) {
+                condition_tasks.wait(lock, [&]{
+                    return (!queue_tasks.empty() || !running);
+                });
+            }
+        }
+    }
+}
+
+void server_queue::cleanup_pending_task(int id_target) {
+    // no need lock because this is called exclusively by post()
+    auto rm_func = [id_target](const server_task & task) {
+        return task.id == id_target;
+    };
+    queue_tasks.erase(
+        std::remove_if(queue_tasks.begin(),          queue_tasks.end(),          rm_func),
+        queue_tasks.end());
+    queue_tasks_deferred.erase(
+        std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+        queue_tasks_deferred.end());
+}
+
+//
+// server_response
+//
+
+void server_response::add_waiting_task_id(int id_task) {
+    RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+
+    std::unique_lock<std::mutex> lock(mutex_results);
+    waiting_task_ids.insert(id_task);
+}
+
+void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
+    std::unique_lock<std::mutex> lock(mutex_results);
+
+    for (const auto & task : tasks) {
+        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
+        waiting_task_ids.insert(task.id);
+    }
+}
+
+void server_response::remove_waiting_task_id(int id_task) {
+    RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+
+    std::unique_lock<std::mutex> lock(mutex_results);
+    waiting_task_ids.erase(id_task);
+    // make sure to clean up all pending results
+    queue_results.erase(
+        std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
+            return res->id == id_task;
+        }),
+        queue_results.end());
+}
+
+void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+    std::unique_lock<std::mutex> lock(mutex_results);
+
+    for (const auto & id_task : id_tasks) {
+        RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+        waiting_task_ids.erase(id_task);
+    }
+}
+
+server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
+    while (true) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        condition_results.wait(lock, [&]{
+            if (!running) {
+                RES_DBG("%s : queue result stop\n", __func__);
+                std::terminate(); // we cannot return here since the caller is HTTP code
+            }
+            return !queue_results.empty();
+        });
+
+        for (size_t i = 0; i < queue_results.size(); i++) {
+            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                server_task_result_ptr res = std::move(queue_results[i]);
+                queue_results.erase(queue_results.begin() + i);
+                return res;
+            }
+        }
+    }
+
+    // should never reach here
+}
+
+server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
+    while (true) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+
+        for (int i = 0; i < (int) queue_results.size(); i++) {
+            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                server_task_result_ptr res = std::move(queue_results[i]);
+                queue_results.erase(queue_results.begin() + i);
+                return res;
+            }
+        }
+
+        std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
+        if (!running) {
+            RES_DBG("%s : queue result stop\n", __func__);
+            std::terminate(); // we cannot return here since the caller is HTTP code
+        }
+        if (cr_res == std::cv_status::timeout) {
+            return nullptr;
+        }
+    }
+
+    // should never reach here
+}
+
+server_task_result_ptr server_response::recv(int id_task) {
+    std::unordered_set<int> id_tasks = {id_task};
+    return recv(id_tasks);
+}
+
+void server_response::send(server_task_result_ptr && result) {
+    RES_DBG("sending result for task id = %d\n", result->id);
+
+    std::unique_lock<std::mutex> lock(mutex_results);
+    for (const auto & id_task : waiting_task_ids) {
+        if (result->id == id_task) {
+            RES_DBG("task id = %d pushed to result queue\n", result->id);
+
+            queue_results.emplace_back(std::move(result));
+            condition_results.notify_all();
+            return;
+        }
+    }
+}
+
+void server_response::terminate() {
+    running = false;
+    condition_results.notify_all();
+}

+ 110 - 0
tools/server/server-queue.h

@@ -0,0 +1,110 @@
+#pragma once
+
+#include "server-task.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <unordered_set>
+
+struct server_queue {
+private:
+    int id = 0;
+    bool running;
+
+    // queues
+    std::deque<server_task> queue_tasks;
+    std::deque<server_task> queue_tasks_deferred;
+
+    std::mutex mutex_tasks;
+    std::condition_variable condition_tasks;
+
+    // callback functions
+    std::function<void(server_task &&)> callback_new_task;
+    std::function<void(void)>           callback_update_slots;
+
+public:
+    // Add a new task to the end of the queue
+    int post(server_task && task, bool front = false);
+
+    // multi-task version of post()
+    int post(std::vector<server_task> && tasks, bool front = false);
+
+    // Add a new task, but defer until one slot is available
+    void defer(server_task && task);
+
+    // Get the next id for creating a new task
+    int get_new_id();
+
+    // Register function to process a new task
+    void on_new_task(std::function<void(server_task &&)> callback);
+
+    // Register the function to be called when all slots data is ready to be processed
+    void on_update_slots(std::function<void(void)> callback);
+
+    // Call when the state of one slot is changed, it will move one task from deferred to main queue
+    void pop_deferred_task();
+
+    // end the start_loop routine
+    void terminate();
+
+    /**
+     * Main loop consists of these steps:
+     * - Wait until a new task arrives
+     * - Process the task (i.e. maybe copy data into slot)
+     * - Check if multitask is finished
+     * - Update all slots
+     */
+    void start_loop();
+
+    // for metrics
+    size_t queue_tasks_deferred_size() {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        return queue_tasks_deferred.size();
+    }
+
+private:
+    void cleanup_pending_task(int id_target);
+};
+
+struct server_response {
+private:
+    bool running = true;
+
+    // for keeping track of all tasks waiting for the result
+    std::unordered_set<int> waiting_task_ids;
+
+    // the main result queue (using ptr for polymorphism)
+    std::vector<server_task_result_ptr> queue_results;
+
+    std::mutex mutex_results;
+    std::condition_variable condition_results;
+
+public:
+    // add the id_task to the list of tasks waiting for response
+    void add_waiting_task_id(int id_task);
+
+    void add_waiting_tasks(const std::vector<server_task> & tasks);
+
+    // when the request is finished, we can remove task associated with it
+    void remove_waiting_task_id(int id_task);
+
+    // remove multiple tasks from waiting list
+    void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
+
+    // This function blocks the thread until there is a response for one of the id_tasks
+    server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
+
+    // same as recv(), but have timeout in seconds
+    // if timeout is reached, nullptr is returned
+    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
+
+    // single-task version of recv()
+    server_task_result_ptr recv(int id_task);
+
+    // Send a new result to a waiting id_task
+    void send(server_task_result_ptr && result);
+
+    // terminate the waiting loop
+    void terminate();
+};

+ 1192 - 0
tools/server/server-task.cpp

@@ -0,0 +1,1192 @@
+#include "server-common.h"
+#include "server-task.h"
+
+#include "common.h"
+#include "llama.h"
+#include "chat.h"
+#include "sampling.h"
+#include "json-schema-to-grammar.h"
+
+using json = nlohmann::ordered_json;
+
+//
+// task_params
+//
+
+json task_params::format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+    return data;
+}
+
+json task_params::to_json(bool only_metrics) const {
+    std::vector<std::string> samplers;
+    samplers.reserve(sampling.samplers.size());
+    for (const auto & sampler : sampling.samplers) {
+        samplers.emplace_back(common_sampler_type_to_str(sampler));
+    }
+
+    json lora = json::array();
+    for (size_t i = 0; i < this->lora.size(); ++i) {
+        lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+    }
+
+    if (only_metrics) {
+        return json {
+            {"seed",                      sampling.seed},
+            {"temperature",               sampling.temp},
+            {"dynatemp_range",            sampling.dynatemp_range},
+            {"dynatemp_exponent",         sampling.dynatemp_exponent},
+            {"top_k",                     sampling.top_k},
+            {"top_p",                     sampling.top_p},
+            {"min_p",                     sampling.min_p},
+            {"top_n_sigma",               sampling.top_n_sigma},
+            {"xtc_probability",           sampling.xtc_probability},
+            {"xtc_threshold",             sampling.xtc_threshold},
+            {"typical_p",                 sampling.typ_p},
+            {"repeat_last_n",             sampling.penalty_last_n},
+            {"repeat_penalty",            sampling.penalty_repeat},
+            {"presence_penalty",          sampling.penalty_present},
+            {"frequency_penalty",         sampling.penalty_freq},
+            {"dry_multiplier",            sampling.dry_multiplier},
+            {"dry_base",                  sampling.dry_base},
+            {"dry_allowed_length",        sampling.dry_allowed_length},
+            {"dry_penalty_last_n",        sampling.dry_penalty_last_n},
+            {"mirostat",                  sampling.mirostat},
+            {"mirostat_tau",              sampling.mirostat_tau},
+            {"mirostat_eta",              sampling.mirostat_eta},
+            {"max_tokens",                n_predict},
+            {"n_predict",                 n_predict}, // TODO: deduplicate?
+            {"n_keep",                    n_keep},
+            {"n_discard",                 n_discard},
+            {"ignore_eos",                sampling.ignore_eos},
+            {"stream",                    stream},
+            {"n_probs",                   sampling.n_probs},
+            {"min_keep",                  sampling.min_keep},
+            {"chat_format",               common_chat_format_name(oaicompat_chat_syntax.format)},
+            {"reasoning_format",          common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
+            {"reasoning_in_content",      oaicompat_chat_syntax.reasoning_in_content},
+            {"thinking_forced_open",      oaicompat_chat_syntax.thinking_forced_open},
+            {"samplers",                  samplers},
+            {"speculative.n_max",         speculative.n_max},
+            {"speculative.n_min",         speculative.n_min},
+            {"speculative.p_min",         speculative.p_min},
+            {"timings_per_token",         timings_per_token},
+            {"post_sampling_probs",       post_sampling_probs},
+            {"lora",                      lora},
+        };
+    }
+
+    auto grammar_triggers = json::array();
+    for (const auto & trigger : sampling.grammar_triggers) {
+        server_grammar_trigger ct(trigger);
+        grammar_triggers.push_back(ct.to_json());
+    }
+
+    return json {
+        {"seed",                      sampling.seed},
+        {"temperature",               sampling.temp},
+        {"dynatemp_range",            sampling.dynatemp_range},
+        {"dynatemp_exponent",         sampling.dynatemp_exponent},
+        {"top_k",                     sampling.top_k},
+        {"top_p",                     sampling.top_p},
+        {"min_p",                     sampling.min_p},
+        {"top_n_sigma",               sampling.top_n_sigma},
+        {"xtc_probability",           sampling.xtc_probability},
+        {"xtc_threshold",             sampling.xtc_threshold},
+        {"typical_p",                 sampling.typ_p},
+        {"repeat_last_n",             sampling.penalty_last_n},
+        {"repeat_penalty",            sampling.penalty_repeat},
+        {"presence_penalty",          sampling.penalty_present},
+        {"frequency_penalty",         sampling.penalty_freq},
+        {"dry_multiplier",            sampling.dry_multiplier},
+        {"dry_base",                  sampling.dry_base},
+        {"dry_allowed_length",        sampling.dry_allowed_length},
+        {"dry_penalty_last_n",        sampling.dry_penalty_last_n},
+        {"dry_sequence_breakers",     sampling.dry_sequence_breakers},
+        {"mirostat",                  sampling.mirostat},
+        {"mirostat_tau",              sampling.mirostat_tau},
+        {"mirostat_eta",              sampling.mirostat_eta},
+        {"stop",                      antiprompt},
+        {"max_tokens",                n_predict},
+        {"n_predict",                 n_predict}, // TODO: deduplicate?
+        {"n_keep",                    n_keep},
+        {"n_discard",                 n_discard},
+        {"ignore_eos",                sampling.ignore_eos},
+        {"stream",                    stream},
+        {"logit_bias",                format_logit_bias(sampling.logit_bias)},
+        {"n_probs",                   sampling.n_probs},
+        {"min_keep",                  sampling.min_keep},
+        {"grammar",                   sampling.grammar},
+        {"grammar_lazy",              sampling.grammar_lazy},
+        {"grammar_triggers",          grammar_triggers},
+        {"preserved_tokens",          sampling.preserved_tokens},
+        {"chat_format",               common_chat_format_name(oaicompat_chat_syntax.format)},
+        {"reasoning_format",          common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
+        {"reasoning_in_content",      oaicompat_chat_syntax.reasoning_in_content},
+        {"thinking_forced_open",      oaicompat_chat_syntax.thinking_forced_open},
+        {"samplers",                  samplers},
+        {"speculative.n_max",         speculative.n_max},
+        {"speculative.n_min",         speculative.n_min},
+        {"speculative.p_min",         speculative.p_min},
+        {"timings_per_token",         timings_per_token},
+        {"post_sampling_probs",       post_sampling_probs},
+        {"lora",                      lora},
+    };
+}
+
+//
+// server_task
+//
+
+task_params server_task::params_from_json_cmpl(
+        const llama_context * ctx,
+        const common_params & params_base,
+        const json & data) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    task_params params;
+
+    // Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
+    task_params defaults;
+    defaults.sampling    = params_base.sampling;
+    defaults.speculative = params_base.speculative;
+    defaults.n_keep      = params_base.n_keep;
+    defaults.n_predict   = params_base.n_predict;
+    defaults.antiprompt  = params_base.antiprompt;
+
+    // enabling this will output extra debug information in the HTTP responses from the server
+    params.verbose           = params_base.verbosity > 9;
+    params.timings_per_token = json_value(data, "timings_per_token", false);
+
+    params.stream           = json_value(data,       "stream",             false);
+    auto stream_opt         = json_value(data,       "stream_options",     json::object());
+    params.include_usage    = json_value(stream_opt, "include_usage",      false);
+    params.cache_prompt     = json_value(data,       "cache_prompt",       true);
+    params.return_tokens    = json_value(data,       "return_tokens",      false);
+    params.return_progress  = json_value(data,       "return_progress",    false);
+    params.n_predict        = json_value(data,       "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
+    params.n_indent         = json_value(data,       "n_indent",           defaults.n_indent);
+    params.n_keep           = json_value(data,       "n_keep",             defaults.n_keep);
+    params.n_discard        = json_value(data,       "n_discard",          defaults.n_discard);
+    //params.t_max_prompt_ms  = json_value(data,       "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
+    params.t_max_predict_ms = json_value(data,       "t_max_predict_ms",   defaults.t_max_predict_ms);
+    params.response_fields  = json_value(data,       "response_fields",    std::vector<std::string>());
+
+    params.sampling.top_k              = json_value(data, "top_k",               defaults.sampling.top_k);
+    params.sampling.top_p              = json_value(data, "top_p",               defaults.sampling.top_p);
+    params.sampling.min_p              = json_value(data, "min_p",               defaults.sampling.min_p);
+    params.sampling.top_n_sigma        = json_value(data, "top_n_sigma",         defaults.sampling.top_n_sigma);
+    params.sampling.xtc_probability    = json_value(data, "xtc_probability",     defaults.sampling.xtc_probability);
+    params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",       defaults.sampling.xtc_threshold);
+    params.sampling.typ_p              = json_value(data, "typical_p",           defaults.sampling.typ_p);
+    params.sampling.temp               = json_value(data, "temperature",         defaults.sampling.temp);
+    params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",      defaults.sampling.dynatemp_range);
+    params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",   defaults.sampling.dynatemp_exponent);
+    params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",       defaults.sampling.penalty_last_n);
+    params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",      defaults.sampling.penalty_repeat);
+    params.sampling.penalty_freq       = json_value(data, "frequency_penalty",   defaults.sampling.penalty_freq);
+    params.sampling.penalty_present    = json_value(data, "presence_penalty",    defaults.sampling.penalty_present);
+    params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",      defaults.sampling.dry_multiplier);
+    params.sampling.dry_base           = json_value(data, "dry_base",            defaults.sampling.dry_base);
+    params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length",  defaults.sampling.dry_allowed_length);
+    params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n",  defaults.sampling.dry_penalty_last_n);
+    params.sampling.mirostat           = json_value(data, "mirostat",            defaults.sampling.mirostat);
+    params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",        defaults.sampling.mirostat_tau);
+    params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",        defaults.sampling.mirostat_eta);
+    params.sampling.seed               = json_value(data, "seed",                defaults.sampling.seed);
+    params.sampling.n_probs            = json_value(data, "n_probs",             defaults.sampling.n_probs);
+    params.sampling.min_keep           = json_value(data, "min_keep",            defaults.sampling.min_keep);
+    params.post_sampling_probs         = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
+
+    params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+    params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+    params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+    params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
+    params.speculative.n_min = std::max(params.speculative.n_min, 0);
+    params.speculative.n_max = std::max(params.speculative.n_max, 0);
+
+    // Use OpenAI API logprobs only if n_probs wasn't provided
+    if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
+        params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
+    }
+
+    if (data.contains("lora")) {
+        if (data.at("lora").is_array()) {
+            params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
+        } else {
+            throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+        }
+    } else {
+        params.lora = params_base.lora_adapters;
+    }
+
+    // TODO: add more sanity checks for the input parameters
+
+    if (params.sampling.penalty_last_n < -1) {
+        throw std::runtime_error("Error: repeat_last_n must be >= -1");
+    }
+
+    if (params.sampling.dry_penalty_last_n < -1) {
+        throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        // note: should be the slot's context and not the full context, but it's ok
+        params.sampling.penalty_last_n = llama_n_ctx(ctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+    }
+
+    if (params.sampling.dry_base < 1.0f) {
+        params.sampling.dry_base = defaults.sampling.dry_base;
+    }
+
+    // sequence breakers for DRY
+    {
+        // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+        // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+        if (data.contains("dry_sequence_breakers")) {
+            params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+            if (params.sampling.dry_sequence_breakers.empty()) {
+                throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
+            }
+        }
+    }
+
+    // process "json_schema" and "grammar"
+    if (data.contains("json_schema") && !data.contains("grammar")) {
+        try {
+            auto schema                  = json_value(data, "json_schema", json::object());
+            SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+            params.sampling.grammar      = json_schema_to_grammar(schema);
+            SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
+        } catch (const std::exception & e) {
+            throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+        }
+    } else {
+        params.sampling.grammar      = json_value(data, "grammar", defaults.sampling.grammar);
+        SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+        params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+        SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
+    }
+
+    {
+        auto it = data.find("chat_format");
+        if (it != data.end()) {
+            params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
+            SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
+        } else {
+            params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
+        }
+        common_reasoning_format reasoning_format = params_base.reasoning_format;
+        if (data.contains("reasoning_format")) {
+            reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
+        }
+        params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
+        params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
+        params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
+        params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
+    }
+
+    {
+        const auto preserved_tokens = data.find("preserved_tokens");
+        if (preserved_tokens != data.end()) {
+            for (const auto & t : *preserved_tokens) {
+                auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
+                if (ids.size() == 1) {
+                    SRV_DBG("Preserved token: %d\n", ids[0]);
+                    params.sampling.preserved_tokens.insert(ids[0]);
+                } else {
+                    // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
+                    SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
+                }
+            }
+        }
+        const auto grammar_triggers = data.find("grammar_triggers");
+        if (grammar_triggers != data.end()) {
+            for (const auto & t : *grammar_triggers) {
+                server_grammar_trigger ct(t);
+                if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+                    const auto & word = ct.value.value;
+                    auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        auto token = ids[0];
+                        if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
+                            throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
+                        }
+                        SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
+                        common_grammar_trigger trigger;
+                        trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+                        trigger.value = word;
+                        trigger.token = token;
+                        params.sampling.grammar_triggers.push_back(std::move(trigger));
+                    } else {
+                        SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
+                        params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+                    }
+                } else {
+                    if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
+                        SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
+                    } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
+                        SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
+                    } else {
+                        throw std::runtime_error("Unknown grammar trigger type");
+                    }
+                    params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
+                }
+            }
+        }
+        if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
+            throw std::runtime_error("Error: no triggers set for lazy grammar!");
+        }
+    }
+
+    {
+        params.sampling.logit_bias.clear();
+
+        const auto & logit_bias = data.find("logit_bias");
+        if (logit_bias != data.end() && logit_bias->is_array()) {
+            const int n_vocab = llama_vocab_n_tokens(vocab);
+            for (const auto & el : *logit_bias) {
+                // TODO: we may want to throw errors here, in case "el" is incorrect
+                if (el.is_array() && el.size() == 2) {
+                    float bias;
+                    if (el[1].is_number()) {
+                        bias = el[1].get<float>();
+                    } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+                        bias = -INFINITY;
+                    } else {
+                        continue;
+                    }
+
+                    if (el[0].is_number_integer()) {
+                        llama_token tok = el[0].get<llama_token>();
+                        if (tok >= 0 && tok < n_vocab) {
+                            params.sampling.logit_bias.push_back({tok, bias});
+                        }
+                    } else if (el[0].is_string()) {
+                        auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
+                        for (auto tok : toks) {
+                            params.sampling.logit_bias.push_back({tok, bias});
+                        }
+                    }
+                }
+            }
+        } else if (logit_bias != data.end() && logit_bias->is_object()) {
+            const int n_vocab = llama_vocab_n_tokens(vocab);
+            for (const auto & el : logit_bias->items()) {
+                float bias;
+                const auto & key = el.key();
+                const auto & value = el.value();
+                if (value.is_number()) {
+                    bias = value.get<float>();
+                } else if (value.is_boolean() && !value.get<bool>()) {
+                    bias = -INFINITY;
+                } else {
+                    continue;
+                }
+
+                char *end;
+                llama_token tok = strtol(key.c_str(), &end, 10);
+                if (*end == 0) {
+                    if (tok >= 0 && tok < n_vocab) {
+                        params.sampling.logit_bias.push_back({tok, bias});
+                    }
+                } else {
+                    auto toks = common_tokenize(vocab, key, false);
+                    for (auto tok : toks) {
+                        params.sampling.logit_bias.push_back({tok, bias});
+                    }
+                }
+            }
+        }
+
+        params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
+        if (params.sampling.ignore_eos) {
+            params.sampling.logit_bias.insert(
+                    params.sampling.logit_bias.end(),
+                    defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
+        }
+    }
+
+    {
+        params.antiprompt.clear();
+
+        const auto & stop = data.find("stop");
+        if (stop != data.end() && stop->is_array()) {
+            for (const auto & word : *stop) {
+                if (!word.empty()) {
+                    params.antiprompt.push_back(word);
+                }
+            }
+        }
+        // set reverse prompt from cli args if not set in the request
+        if (params.antiprompt.empty()) {
+            params.antiprompt = defaults.antiprompt;
+        }
+    }
+
+    {
+        const auto samplers = data.find("samplers");
+        if (samplers != data.end()) {
+            if (samplers->is_array()) {
+                params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
+            } else if (samplers->is_string()){
+                params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
+            }
+        } else {
+            params.sampling.samplers = defaults.sampling.samplers;
+        }
+    }
+
+    std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
+    params.oaicompat_model = json_value(data, "model", model_name);
+
+    return params;
+}
+
+//
+// result_timings
+//
+
+json result_timings::to_json() const {
+    json base = {
+        {"cache_n",                cache_n},
+
+        {"prompt_n",               prompt_n},
+        {"prompt_ms",              prompt_ms},
+        {"prompt_per_token_ms",    prompt_per_token_ms},
+        {"prompt_per_second",      prompt_per_second},
+
+        {"predicted_n",            predicted_n},
+        {"predicted_ms",           predicted_ms},
+        {"predicted_per_token_ms", predicted_per_token_ms},
+        {"predicted_per_second",   predicted_per_second},
+    };
+
+    if (draft_n > 0) {
+        base["draft_n"] = draft_n;
+        base["draft_n_accepted"] = draft_n_accepted;
+    }
+
+    return base;
+}
+
+//
+// result_prompt_progress
+//
+json result_prompt_progress::to_json() const {
+    return json {
+        {"total",     total},
+        {"cache",     cache},
+        {"processed", processed},
+        {"time_ms",   time_ms},
+    };
+}
+
+static inline std::string stop_type_to_str(stop_type type) {
+    switch (type) {
+        case STOP_TYPE_EOS:   return "eos";
+        case STOP_TYPE_WORD:  return "word";
+        case STOP_TYPE_LIMIT: return "limit";
+        default:              return "none";
+    }
+}
+
+//
+// completion_token_output
+//
+
+json completion_token_output::to_json(bool post_sampling_probs) const {
+    json probs_for_token = json::array();
+    for (const auto & p : probs) {
+        std::string txt(p.txt);
+        txt.resize(validate_utf8(txt));
+        probs_for_token.push_back(json {
+            {"id",      p.tok},
+            {"token",   txt},
+            {"bytes",   str_to_bytes(p.txt)},
+            {
+                post_sampling_probs ? "prob" : "logprob",
+                post_sampling_probs ? p.prob : logarithm(p.prob)
+            },
+        });
+    }
+    return probs_for_token;
+}
+
+json completion_token_output::probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
+    json out = json::array();
+    for (const auto & p : probs) {
+        std::string txt(p.text_to_send);
+        txt.resize(validate_utf8(txt));
+        out.push_back(json {
+            {"id",           p.tok},
+            {"token",        txt},
+            {"bytes",        str_to_bytes(p.text_to_send)},
+            {
+                post_sampling_probs ? "prob" : "logprob",
+                post_sampling_probs ? p.prob : logarithm(p.prob)
+            },
+            {
+                post_sampling_probs ? "top_probs" : "top_logprobs",
+                p.to_json(post_sampling_probs)
+            },
+        });
+    }
+    return out;
+}
+
+float completion_token_output::logarithm(float x) {
+    // nlohmann::json converts -inf to null, so we need to prevent that
+    return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+}
+
+std::vector<unsigned char> completion_token_output::str_to_bytes(const std::string & str) {
+    std::vector<unsigned char> bytes;
+    for (unsigned char c : str) {
+        bytes.push_back(c);
+    }
+    return bytes;
+}
+
+//
+// server_task_result_cmpl_final
+//
+json server_task_result_cmpl_final::to_json() {
+    switch (oaicompat) {
+        case OAICOMPAT_TYPE_NONE:
+            return to_json_non_oaicompat();
+        case OAICOMPAT_TYPE_COMPLETION:
+            return to_json_oaicompat();
+        case OAICOMPAT_TYPE_CHAT:
+            return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+        default:
+            GGML_ASSERT(false && "Invalid oaicompat_type");
+    }
+}
+
+json server_task_result_cmpl_final::to_json_non_oaicompat() {
+    json res = json {
+        {"index",               index},
+        {"content",             stream ? "" : content}, // in stream mode, content is already in last partial chunk
+        {"tokens",              stream ? llama_tokens {} : tokens},
+        {"id_slot",             id_slot},
+        {"stop",                true},
+        {"model",               oaicompat_model},
+        {"tokens_predicted",    n_decoded},
+        {"tokens_evaluated",    n_prompt_tokens},
+        {"generation_settings", generation_params.to_json()},
+        {"prompt",              prompt},
+        {"has_new_line",        has_new_line},
+        {"truncated",           truncated},
+        {"stop_type",           stop_type_to_str(stop)},
+        {"stopping_word",       stopping_word},
+        {"tokens_cached",       n_tokens_cached},
+        {"timings",             timings.to_json()},
+    };
+    if (!stream && !probs_output.empty()) {
+        res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
+    }
+    return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
+}
+
+json server_task_result_cmpl_final::to_json_oaicompat() {
+    std::time_t t = std::time(0);
+    json logprobs = json(nullptr); // OAI default to null
+    if (!stream && probs_output.size() > 0) {
+        logprobs = json{
+            {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+        };
+    }
+    json finish_reason = "length";
+    if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+        finish_reason = "stop";
+    }
+    json res = json {
+        {"choices",            json::array({
+            json{
+                {"text",          stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                {"index",         index},
+                {"logprobs",      logprobs},
+                {"finish_reason", finish_reason},
+            }
+        })},
+        {"created",            t},
+        {"model",              oaicompat_model},
+        {"system_fingerprint", build_info},
+        {"object",             "text_completion"},
+        {"usage", json {
+            {"completion_tokens", n_decoded},
+            {"prompt_tokens",     n_prompt_tokens},
+            {"total_tokens",      n_decoded + n_prompt_tokens}
+        }},
+        {"id", oaicompat_cmpl_id}
+    };
+
+    // extra fields for debugging purposes
+    if (verbose) {
+        res["__verbose"] = to_json_non_oaicompat();
+    }
+    if (timings.prompt_n >= 0) {
+        res.push_back({"timings", timings.to_json()});
+    }
+
+    return res;
+}
+
+json server_task_result_cmpl_final::to_json_oaicompat_chat() {
+    std::string finish_reason = "length";
+    common_chat_msg msg;
+    if (!oaicompat_msg.empty()) {
+        msg = oaicompat_msg;
+    } else {
+        msg.role = "assistant";
+        msg.content = content;
+    }
+    if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+        finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+    }
+
+    json choice {
+        {"finish_reason", finish_reason},
+        {"index", 0},
+        {"message", msg.to_json_oaicompat<json>()},
+    };
+
+    if (!stream && probs_output.size() > 0) {
+        choice["logprobs"] = json{
+            {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+        };
+    }
+
+    std::time_t t = std::time(0);
+
+    json res = json {
+        {"choices",            json::array({choice})},
+        {"created",            t},
+        {"model",              oaicompat_model},
+        {"system_fingerprint", build_info},
+        {"object",             "chat.completion"},
+        {"usage", json {
+            {"completion_tokens", n_decoded},
+            {"prompt_tokens",     n_prompt_tokens},
+            {"total_tokens",      n_decoded + n_prompt_tokens}
+        }},
+        {"id", oaicompat_cmpl_id}
+    };
+
+    // extra fields for debugging purposes
+    if (verbose) {
+        res["__verbose"] = to_json_non_oaicompat();
+    }
+    if (timings.prompt_n >= 0) {
+        res.push_back({"timings", timings.to_json()});
+    }
+
+    return res;
+}
+
+json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
+    std::time_t t = std::time(0);
+    std::string finish_reason = "length";
+    if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+        finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
+    }
+
+    json deltas = json::array();
+    for (const auto & diff : oaicompat_msg_diffs) {
+        deltas.push_back({
+            {"choices", json::array({
+                json {
+                    {"finish_reason", nullptr},
+                    {"index", 0},
+                    {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
+                },
+            })},
+            {"created", t},
+            {"id", oaicompat_cmpl_id},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "chat.completion.chunk"},
+        });
+    }
+
+    deltas.push_back({
+        {"choices", json::array({
+            json {
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"delta", json::object()},
+            },
+        })},
+        {"created",            t},
+        {"id",                 oaicompat_cmpl_id},
+        {"model",              oaicompat_model},
+        {"system_fingerprint", build_info},
+        {"object",             "chat.completion.chunk"},
+    });
+
+    if (include_usage) {
+        // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
+        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+        deltas.push_back({
+            {"choices", json::array()},
+            {"created",            t},
+            {"id",                 oaicompat_cmpl_id},
+            {"model",              oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object",             "chat.completion.chunk"},
+            {"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens",     n_prompt_tokens},
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+            }},
+        });
+    }
+
+    if (timings.prompt_n >= 0) {
+        deltas.back().push_back({"timings", timings.to_json()});
+    }
+
+    // extra fields for debugging purposes
+    if (verbose && !deltas.empty()) {
+        deltas.front()["__verbose"] = to_json_non_oaicompat();
+    }
+
+    return deltas;
+}
+
+//
+// server_task_result_cmpl_partial
+//
+json server_task_result_cmpl_partial::to_json() {
+    switch (oaicompat) {
+        case OAICOMPAT_TYPE_NONE:
+            return to_json_non_oaicompat();
+        case OAICOMPAT_TYPE_COMPLETION:
+            return to_json_oaicompat();
+        case OAICOMPAT_TYPE_CHAT:
+            return to_json_oaicompat_chat();
+        default:
+            GGML_ASSERT(false && "Invalid oaicompat_type");
+    }
+}
+
+json server_task_result_cmpl_partial::to_json_non_oaicompat() {
+    // non-OAI-compat JSON
+    json res = json {
+        {"index",            index},
+        {"content",          content},
+        {"tokens",           tokens},
+        {"stop",             false},
+        {"id_slot",          id_slot},
+        {"tokens_predicted", n_decoded},
+        {"tokens_evaluated", n_prompt_tokens},
+    };
+    // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
+    if (timings.prompt_n > 0) {
+        res.push_back({"timings", timings.to_json()});
+    }
+    if (is_progress) {
+        res.push_back({"prompt_progress", progress.to_json()});
+    }
+    if (!prob_output.probs.empty()) {
+        res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
+    }
+    return res;
+}
+
+json server_task_result_cmpl_partial::to_json_oaicompat() {
+    std::time_t t = std::time(0);
+    json logprobs = json(nullptr); // OAI default to null
+    if (prob_output.probs.size() > 0) {
+        logprobs = json{
+            {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+        };
+    }
+    json res = json {
+        {"choices",            json::array({
+            json{
+                {"text",          content},
+                {"index",         index},
+                {"logprobs",      logprobs},
+                {"finish_reason", nullptr},
+            }
+        })},
+        {"created",            t},
+        {"model",              oaicompat_model},
+        {"system_fingerprint", build_info},
+        {"object",             "text_completion"},
+        {"id",                 oaicompat_cmpl_id}
+    };
+
+    // extra fields for debugging purposes
+    if (verbose) {
+        res["__verbose"] = to_json_non_oaicompat();
+    }
+    if (timings.prompt_n >= 0) {
+        res.push_back({"timings", timings.to_json()});
+    }
+    if (is_progress) {
+        res.push_back({"prompt_progress", progress.to_json()});
+    }
+
+    return res;
+}
+
+json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
+    bool first = n_decoded == 1;
+    std::time_t t = std::time(0);
+    json choices;
+
+    std::vector<json> deltas;
+    auto add_delta = [&](const json & delta) {
+        deltas.push_back({
+            {"choices", json::array({
+                json {
+                    {"finish_reason", nullptr},
+                    {"index", 0},
+                    {"delta", delta},
+                },
+            })},
+            {"created", t},
+            {"id", oaicompat_cmpl_id},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "chat.completion.chunk"},
+        });
+    };
+    // We have to send an initial update to conform to openai behavior
+    if (first || is_progress) {
+        add_delta({
+            {"role", "assistant"},
+            {"content", nullptr},
+        });
+    }
+
+    for (const auto & diff : oaicompat_msg_diffs) {
+        add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
+    }
+
+    if (!deltas.empty()) {
+        auto & last_json = deltas[deltas.size() - 1];
+        GGML_ASSERT(last_json.at("choices").size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            last_json.at("choices").at(0)["logprobs"] = json {
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+            };
+        }
+
+        if (timings.prompt_n >= 0) {
+            last_json.push_back({"timings", timings.to_json()});
+        }
+        if (is_progress) {
+            last_json.push_back({"prompt_progress", progress.to_json()});
+        }
+    }
+
+    return deltas;
+}
+
+//
+// server_task_result_embd
+//
+json server_task_result_embd::to_json() {
+    return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+        ? to_json_oaicompat()
+        : to_json_non_oaicompat();
+}
+
+json server_task_result_embd::to_json_non_oaicompat() {
+    return json {
+        {"index",     index},
+        {"embedding", embedding},
+    };
+}
+
+json server_task_result_embd::to_json_oaicompat() {
+    return json {
+        {"index",            index},
+        {"embedding",        embedding[0]},
+        {"tokens_evaluated", n_tokens},
+    };
+}
+
+//
+// server_task_result_rerank
+//
+json server_task_result_rerank::to_json() {
+    return json {
+        {"index",            index},
+        {"score",            score},
+        {"tokens_evaluated", n_tokens},
+    };
+}
+
+//
+// server_task_result_error
+//
+json server_task_result_error::to_json() {
+    json res = format_error_response(err_msg, err_type);
+    if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+        res["n_prompt_tokens"] = n_prompt_tokens;
+        res["n_ctx"]           = n_ctx;
+    }
+    return res;
+}
+
+//
+// server_task_result_metrics
+//
+json server_task_result_metrics::to_json() {
+    return json {
+        { "idle",                            n_idle_slots },
+        { "processing",                      n_processing_slots },
+        { "deferred",                        n_tasks_deferred },
+        { "t_start",                         t_start },
+
+        { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
+        { "t_tokens_generation_total",       t_tokens_generation_total },
+        { "n_tokens_predicted_total",        n_tokens_predicted_total },
+        { "t_prompt_processing_total",       t_prompt_processing_total },
+
+        { "n_tokens_max",                    n_tokens_max },
+
+        { "n_prompt_tokens_processed",       n_prompt_tokens_processed },
+        { "t_prompt_processing",             t_prompt_processing },
+        { "n_tokens_predicted",              n_tokens_predicted },
+        { "t_tokens_generation",             t_tokens_generation },
+
+        { "n_decode_total",                  n_decode_total },
+        { "n_busy_slots_total",              n_busy_slots_total },
+
+        { "slots",                           slots_data },
+    };
+}
+
+//
+// server_task_result_slot_save_load
+//
+json server_task_result_slot_save_load::to_json() {
+    if (is_save) {
+        return json {
+            { "id_slot",   id_slot },
+            { "filename",  filename },
+            { "n_saved",   n_tokens },
+            { "n_written", n_bytes },
+            { "timings", {
+                { "save_ms", t_ms }
+            }},
+        };
+    }
+
+    return json {
+        { "id_slot",    id_slot },
+        { "filename",   filename },
+        { "n_restored", n_tokens },
+        { "n_read",     n_bytes },
+        { "timings", {
+            { "restore_ms", t_ms }
+        }},
+    };
+}
+
+//
+// server_task_result_slot_erase
+//
+json server_task_result_slot_erase::to_json() {
+    return json {
+        { "id_slot",  id_slot },
+        { "n_erased", n_erased },
+    };
+}
+
+//
+// server_task_result_apply_lora
+//
+
+json server_task_result_apply_lora::to_json() {
+    return json {{ "success", true }};
+}
+
+//
+// server_prompt_cache
+//
+size_t server_prompt_cache::size() const {
+    size_t res = 0;
+
+    for (const auto & state : states) {
+        res += state.size();
+    }
+
+    return res;
+}
+
+size_t server_prompt_cache::n_tokens() const {
+    size_t res = 0;
+
+    for (const auto & state : states) {
+        res += state.n_tokens();
+    }
+
+    return res;
+}
+
+server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
+    // first check if the current state is contained fully in the cache
+    for (auto it = states.begin(); it != states.end(); ++it) {
+        const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
+
+        if (cur_lcp_len == (int) prompt.tokens.size()) {
+            SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
+            return nullptr;
+        }
+    }
+
+    // next, remove any cached prompts that are fully contained in the current prompt
+    for (auto it = states.begin(); it != states.end();) {
+        const int len = it->tokens.get_common_prefix(prompt.tokens);
+
+        if (len == (int) it->tokens.size()) {
+            SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
+
+            it = states.erase(it);
+        } else {
+            ++it;
+        }
+    }
+
+    std::vector<uint8_t> state_data;
+
+    // check if we can allocate enough memory for the new state
+    try {
+        state_data.resize(state_size);
+    } catch (const std::bad_alloc & e) {
+        SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
+
+        limit_size = std::max<size_t>(1, 0.4*size());
+
+        SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
+
+        update();
+
+        return nullptr;
+    }
+
+    // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
+    auto & cur = states.emplace_back();
+    cur = {
+        /*.tokens      =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
+        /*.data        =*/ std::move(state_data),
+        /*.checkpoints =*/ prompt.checkpoints,
+    };
+
+    return &cur;
+}
+
+bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+    const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
+
+    float f_keep_best = float(lcp_best) / prompt.tokens.size();
+    float sim_best    = float(lcp_best) / tokens_new.size();
+
+    SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+    auto it_best = states.end();
+
+    // find the most similar cached prompt, that would also preserve the most context
+    for (auto it = states.begin(); it != states.end(); ++it) {
+        const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
+
+        const float f_keep_cur = float(lcp_cur) / it->tokens.size();
+        const float sim_cur    = float(lcp_cur) / tokens_new.size();
+
+        // don't trash large prompts
+        if (f_keep_cur < 0.25f) {
+            continue;
+        }
+
+        if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
+            f_keep_best = f_keep_cur;
+            sim_best    = sim_cur;
+
+            it_best = it;
+        }
+    }
+
+    if (it_best != states.end()) {
+        SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+
+        const size_t size = it_best->data.size();
+        const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
+        if (n != size) {
+            SRV_WRN("failed to restore state with size %zu\n", size);
+
+            return false;
+        }
+
+        it_best->data.clear();
+        it_best->data.shrink_to_fit();
+
+        prompt = std::move(*it_best);
+
+        states.erase(it_best);
+    }
+
+    return true;
+}
+
+void server_prompt_cache::update() {
+    if (limit_size > 0) {
+        // always keep at least one state, regardless of the limits
+        while (states.size() > 1 && size() > limit_size) {
+            if (states.empty()) {
+                break;
+            }
+
+            SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
+
+            states.pop_front();
+        }
+    }
+
+    // average size per token
+    const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
+
+    // dynamically increase the token limit if it can fit in the memory limit
+    const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
+
+    if (limit_tokens > 0) {
+        while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
+            if (states.empty()) {
+                break;
+            }
+
+            SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
+                    limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
+
+            states.pop_front();
+        }
+    }
+
+    SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
+            states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
+
+    for (const auto & state : states) {
+        SRV_WRN("   - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
+                (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
+    }
+}

+ 453 - 0
tools/server/server-task.h

@@ -0,0 +1,453 @@
+#pragma once
+
+#include "common.h"
+#include "llama.h"
+
+#include <string>
+#include <unordered_set>
+#include <list>
+
+// TODO: prevent including the whole server-common.h as we only use server_tokens
+#include "server-common.h"
+
+using json = nlohmann::ordered_json;
+
+enum server_task_type {
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
+    SERVER_TASK_TYPE_CANCEL,
+    SERVER_TASK_TYPE_NEXT_RESPONSE,
+    SERVER_TASK_TYPE_METRICS,
+    SERVER_TASK_TYPE_SLOT_SAVE,
+    SERVER_TASK_TYPE_SLOT_RESTORE,
+    SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
+};
+
+// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
+enum oaicompat_type {
+    OAICOMPAT_TYPE_NONE,
+    OAICOMPAT_TYPE_CHAT,
+    OAICOMPAT_TYPE_COMPLETION,
+    OAICOMPAT_TYPE_EMBEDDING,
+};
+
+enum stop_type {
+    STOP_TYPE_NONE,
+    STOP_TYPE_EOS,
+    STOP_TYPE_WORD,
+    STOP_TYPE_LIMIT,
+};
+
+struct task_params {
+    bool stream          = true;
+    bool include_usage   = false;
+    bool cache_prompt    = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens   = false;
+    bool return_progress = false;
+
+    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
+    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_indent  =  0; // minimum line indentation for the generated text in number of whitespace characters
+
+    int64_t t_max_prompt_ms  = -1; // TODO: implement
+    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
+
+    std::vector<common_adapter_lora_info> lora;
+
+    std::vector<std::string> antiprompt;
+    std::vector<std::string> response_fields;
+    bool timings_per_token = false;
+    bool post_sampling_probs = false;
+
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
+
+    // OAI-compat fields
+    bool                         verbose                   = false;
+    oaicompat_type               oaicompat                 = OAICOMPAT_TYPE_NONE;
+    std::string                  oaicompat_model;
+    std::string                  oaicompat_cmpl_id;
+    common_chat_syntax           oaicompat_chat_syntax;
+
+    // Embeddings
+    int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
+
+    json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
+    json to_json(bool only_metrics = false) const;
+};
+
+struct server_task {
+    int id    = -1; // to be filled by server_queue
+    int index = -1; // used when there are multiple prompts (batch request)
+
+    // used by SERVER_TASK_TYPE_CANCEL
+    int id_target = -1;
+    int id_slot   = -1;
+
+    // used by SERVER_TASK_TYPE_INFERENCE
+    task_params   params;
+    server_tokens tokens;
+
+    server_task_type type;
+
+    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
+    struct slot_action {
+        int slot_id;
+        std::string filename;
+        std::string filepath;
+    };
+    slot_action slot_action;
+
+    // used by SERVER_TASK_TYPE_METRICS
+    bool metrics_reset_bucket = false;
+
+    // used by SERVER_TASK_TYPE_SET_LORA
+    std::vector<common_adapter_lora_info> set_lora;
+
+    server_task() = default;
+
+    server_task(server_task_type type) : type(type) {}
+
+    int32_t n_tokens() const {
+        return tokens.size();
+    }
+
+    static task_params params_from_json_cmpl(
+            const llama_context * ctx,
+            const common_params & params_base,
+            const json & data);
+
+    // utility function
+    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+        std::unordered_set<int> ids(tasks.size());
+        for (size_t i = 0; i < tasks.size(); i++) {
+            ids.insert(tasks[i].id);
+        }
+        return ids;
+    }
+};
+
+struct result_timings {
+    int32_t cache_n = -1;
+
+    int32_t prompt_n = -1;
+    double prompt_ms;
+    double prompt_per_token_ms;
+    double prompt_per_second;
+
+    int32_t predicted_n = -1;
+    double predicted_ms;
+    double predicted_per_token_ms;
+    double predicted_per_second;
+
+    // Optional speculative metrics - only included when > 0
+    int32_t draft_n = 0;
+    int32_t draft_n_accepted = 0;
+
+    json to_json() const;
+};
+
+struct result_prompt_progress {
+    int32_t total = 0;
+    int32_t cache = 0;
+    int32_t processed = 0;
+    int64_t time_ms = 0;
+
+    json to_json() const;
+};
+
+struct server_task_result {
+    int id           = -1;
+    int id_slot      = -1;
+    virtual bool is_error() {
+        // only used by server_task_result_error
+        return false;
+    }
+    virtual bool is_stop() {
+        // only used by server_task_result_cmpl_*
+        return true;
+    }
+    virtual int get_index() {
+        return -1;
+    }
+    virtual json to_json() = 0;
+    virtual ~server_task_result() = default;
+};
+
+// using shared_ptr for polymorphism of server_task_result
+using server_task_result_ptr = std::unique_ptr<server_task_result>;
+
+struct completion_token_output {
+    llama_token tok;
+    float prob;
+    std::string text_to_send;
+    struct prob_info {
+        llama_token tok;
+        std::string txt;
+        float prob;
+    };
+    std::vector<prob_info> probs;
+
+    json to_json(bool post_sampling_probs) const;
+
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
+
+    static float logarithm(float x);
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str);
+
+};
+
+struct server_task_result_cmpl_final : server_task_result {
+    int index = 0;
+
+    std::string content;
+    llama_tokens tokens;
+
+    bool stream;
+    bool include_usage;
+    result_timings timings;
+    std::string prompt;
+
+    bool truncated;
+    int32_t n_decoded;
+    int32_t n_prompt_tokens;
+    int32_t n_tokens_cached;
+    bool has_new_line;
+    std::string stopping_word;
+    stop_type stop = STOP_TYPE_NONE;
+
+    bool post_sampling_probs;
+    std::vector<completion_token_output> probs_output;
+    std::vector<std::string>  response_fields;
+
+    task_params generation_params;
+
+    // OAI-compat fields
+    bool            verbose   = false;
+    oaicompat_type  oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string     oaicompat_model;
+    std::string     oaicompat_cmpl_id;
+    common_chat_msg oaicompat_msg;
+
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual bool is_stop() override {
+        return true; // in stream mode, final responses are considered stop
+    }
+
+    virtual json to_json() override;
+
+    json to_json_non_oaicompat();
+
+    json to_json_oaicompat();
+
+    json to_json_oaicompat_chat();
+
+    json to_json_oaicompat_chat_stream();
+};
+
+struct server_task_result_cmpl_partial : server_task_result {
+    int index = 0;
+
+    std::string  content;
+    llama_tokens tokens;
+
+    int32_t n_decoded;
+    int32_t n_prompt_tokens;
+
+    bool post_sampling_probs;
+    bool is_progress = false;
+    completion_token_output prob_output;
+    result_timings timings;
+    result_prompt_progress progress;
+
+    // OAI-compat fields
+    bool            verbose   = false;
+    oaicompat_type  oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string     oaicompat_model;
+    std::string     oaicompat_cmpl_id;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual bool is_stop() override {
+        return false; // in stream mode, partial responses are not considered stop
+    }
+
+    virtual json to_json() override;
+
+    json to_json_non_oaicompat();
+
+    json to_json_oaicompat();
+
+    json to_json_oaicompat_chat();
+};
+
+struct server_task_result_embd : server_task_result {
+    int index = 0;
+    std::vector<std::vector<float>> embedding;
+
+    int32_t n_tokens;
+
+    // OAI-compat fields
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual json to_json() override;
+
+    json to_json_non_oaicompat();
+
+    json to_json_oaicompat();
+};
+
+struct server_task_result_rerank : server_task_result {
+    int index = 0;
+    float score = -1e6;
+
+    int32_t n_tokens;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual json to_json() override;
+};
+
+struct server_task_result_error : server_task_result {
+    int index = 0;
+    error_type err_type = ERROR_TYPE_SERVER;
+    std::string err_msg;
+
+    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
+    int32_t n_prompt_tokens = 0;
+    int32_t n_ctx           = 0;
+
+    virtual bool is_error() override {
+        return true;
+    }
+
+    virtual json to_json() override;
+};
+
+struct server_task_result_metrics : server_task_result {
+    int n_idle_slots;
+    int n_processing_slots;
+    int n_tasks_deferred;
+    int64_t t_start;
+
+    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
+    uint64_t n_prompt_tokens_processed_total = 0;
+    uint64_t t_prompt_processing_total       = 0;
+    uint64_t n_tokens_predicted_total        = 0;
+    uint64_t t_tokens_generation_total       = 0;
+
+    uint64_t n_tokens_max = 0;
+
+    uint64_t n_prompt_tokens_processed = 0;
+    uint64_t t_prompt_processing       = 0;
+
+    uint64_t n_tokens_predicted  = 0;
+    uint64_t t_tokens_generation = 0;
+
+    uint64_t n_decode_total     = 0;
+    uint64_t n_busy_slots_total = 0;
+
+    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
+    // therefore, we use json to temporarily store the slot.to_json() result
+    json slots_data = json::array();
+
+    virtual json to_json() override;
+};
+
+struct server_task_result_slot_save_load : server_task_result {
+    std::string filename;
+    bool is_save; // true = save, false = load
+
+    size_t n_tokens;
+    size_t n_bytes;
+    double t_ms;
+
+    virtual json to_json() override;
+};
+
+struct server_task_result_slot_erase : server_task_result {
+    size_t n_erased;
+
+    virtual json to_json() override;
+};
+
+struct server_task_result_apply_lora : server_task_result {
+    virtual json to_json() override;
+};
+
+struct server_prompt_checkpoint {
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data;
+
+    size_t size() const {
+        return data.size();
+    }
+};
+
+struct server_prompt {
+    server_tokens tokens;
+
+    std::vector<uint8_t> data;
+
+    std::list<server_prompt_checkpoint> checkpoints;
+
+    size_t size() const {
+        size_t res = data.size();
+
+        for (const auto & checkpoint : checkpoints) {
+            res += checkpoint.size();
+        }
+
+        return res;
+    }
+
+    int n_tokens() const {
+        return tokens.size();
+    }
+};
+
+struct server_prompt_cache {
+    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
+        this->limit_size   = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
+        this->limit_tokens = limit_tokens;
+    }
+
+    std::list<server_prompt> states;
+
+    // in bytes, 0 = no limit
+    size_t limit_size = 0;
+
+    // in tokens, 0 = no limit
+    size_t limit_tokens = 0;
+
+    size_t size() const;
+
+    size_t n_tokens() const;
+
+    server_prompt * alloc(const server_prompt & prompt, size_t state_size);
+
+    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
+
+    void update();
+};

File diff suppressed because it is too large
+ 30 - 1584
tools/server/server.cpp


Some files were not shown because too many files changed in this diff