
server : allow using LoRA adapters per-request (#10994)

* slot.can_batch_with

* lora per request

* test: force disable cache prompt

* move can_batch_with check

* fix condition

* add slow test with llama 8b

* update docs

* move lora change task to queue

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* lora_base

* remove redundant check

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan Son Nguyen, 1 year ago
commit 0da5d86026

+ 6 - 0
examples/server/README.md

@@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to
 
 
 `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.
 
+`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
@@ -945,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo
 
 
 By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
 
+Please note that this value will be overwritten by the `lora` field for each request.
+
 If an adapter is disabled, the scale will be set to 0.
 
 **Response format**
@@ -966,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.
 
 
 ### POST `/lora-adapters`: Set list of LoRA adapters
 
+This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request.
+
 To disable an adapter, either remove it from the list below, or set scale to 0.
 
 **Request format**
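The per-request `lora` field documented above can be exercised with any HTTP client. Below is a minimal client-side sketch, not part of this change, assuming a llama-server instance is already running on `http://localhost:8080` with at least one adapter loaded via `--lora`; the host, port, and adapter id are placeholders.

```python
import requests  # same HTTP client the server test suite already depends on

SERVER_URL = "http://localhost:8080"  # assumed local llama-server instance

# Request a completion with adapter 0 applied at half strength for this request only.
# Adapters omitted from the "lora" list are treated as scale 0.0 for this request.
res = requests.post(f"{SERVER_URL}/completion", json={
    "prompt": "Look in thy glass",
    "n_predict": 32,
    "lora": [{"id": 0, "scale": 0.5}],
})
res.raise_for_status()
print(res.json()["content"])
```

Since requests with different LoRA configurations are not batched together, mixing many distinct `lora` configurations across concurrent requests trades some throughput for this flexibility.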

+ 76 - 40
examples/server/server.cpp

@@ -98,6 +98,8 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -120,6 +122,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict",                 n_predict},     // Server configured n_predict
             {"seed",                      sampling.seed},
@@ -160,6 +167,7 @@ struct slot_params {
             {"speculative.p_min",         speculative.p_min},
             {"timings_per_token",         timings_per_token},
             {"post_sampling_probs",       post_sampling_probs},
+            {"lora",                      lora},
         };
     }
 };
@@ -189,12 +197,16 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
+    // used by SERVER_TASK_TYPE_SET_LORA
+    std::vector<common_lora_adapter_container> set_lora;
+
     server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;
 
@@ -251,6 +263,16 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(lora_base, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = lora_base;
+        }
+
         // TODO: add more sanity checks for the input parameters
 
         if (params.sampling.penalty_last_n < -1) {
@@ -1110,6 +1132,8 @@ struct server_slot {
 
 
     common_speculative * spec = nullptr;
 
+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -1191,6 +1215,11 @@ struct server_slot {
         return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
     }
 
+    bool can_batch_with(server_slot & other_slot) {
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -1600,7 +1629,7 @@ struct server_context {
 
 
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;
 
     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1667,7 +1696,7 @@ struct server_context {
 
 
         model = llama_init.model;
         ctx   = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora  = llama_init.lora_adapters;
 
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1866,6 +1895,12 @@ struct server_context {
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (!are_lora_equal(task.params.lora, slot.lora)) {
+            // if lora is changed, we cannot reuse cached tokens
+            slot.cache_tokens.clear();
+            slot.lora = std::move(task.params.lora);
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2557,7 +2592,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    common_lora_adapters_apply(ctx, loras);
+                    lora = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));
@@ -2634,12 +2669,22 @@ struct server_context {
         // start populating the batch for this iteration
         common_batch_clear(batch);
 
+        // track if given slot can be batched with slots already in the batch
+        server_slot * slot_batched = nullptr;
+
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
                 continue;
             }
 
+            // check if we can batch this slot with the previous one
+            if (!slot_batched) {
+                slot_batched = &slot;
+            } else if (!slot_batched->can_batch_with(slot)) {
+                continue;
+            }
+
             slot.i_batch = batch.n_tokens;
 
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
@@ -2658,15 +2703,18 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // track if this is an embedding or non-embedding batch
-        // if we've added sampled tokens above, we are in non-embedding mode
-        // -1: none, 0: non-embedding, 1: embedding
-        // TODO: make enum
-        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
-
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
+                // check if we can batch this slot with the previous one
+                if (slot.is_processing()) {
+                    if (!slot_batched) {
+                        slot_batched = &slot;
+                    } else if (!slot_batched->can_batch_with(slot)) {
+                        continue;
+                    }
+                }
+
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
@@ -2827,14 +2875,6 @@ struct server_context {
                         }
                     }
 
-                    // check that we are in the right batch_type, if not defer the slot
-                    int slot_type = slot.is_non_causal();
-                    if (batch_type == -1) {
-                        batch_type = slot_type;
-                    } else if (batch_type != slot_type) {
-                        continue;
-                    }
-
                     // keep only the common part
                     if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
@@ -2902,8 +2942,12 @@ struct server_context {
 
 
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
-        // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, batch_type == 1);
+        if (slot_batched) {
+            // make sure we're in the right embedding mode
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }
 
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
-                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+                task.params           = server_task::params_from_json_cmpl(
+                                            ctx_server.model,
+                                            ctx_server.ctx,
+                                            ctx_server.params_base,
+                                            ctx_server.lora,
+                                            data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
@@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) {
 
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -4062,27 +4111,14 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id      = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
+            return;
         }
-
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
+        task.set_lora = parse_lora_request(ctx_server.lora, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
 

+ 6 - 0
examples/server/tests/README.md

@@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
 DEBUG=1 ./tests.sh -s -v -x
 ```
 
+To run single test unit:
+
+```shell
+./tests.sh unit/test_{name of test case here}.py -v -x
+```
+
 Hint: You can compile and run test in single command, useful for local developement:
 
 ```shell

+ 1 - 0
examples/server/tests/requirements.txt

@@ -5,3 +5,4 @@ numpy~=1.26.4
 openai~=1.55.3
 prometheus-client~=0.20.0
 requests~=2.32.3
+wget~=3.2

+ 83 - 10
examples/server/tests/unit/test_lora.py

@@ -1,5 +1,4 @@
 import pytest
-import os
 from utils import *
 
 server = ServerPreset.stories15m_moe()
@@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download lora file if needed
-    file_name = LORA_FILE_URL.split('/').pop()
-    lora_file = f'../../../{file_name}'
-    if not os.path.exists(lora_file):
-        print(f"Downloading {LORA_FILE_URL} to {lora_file}")
-        with open(lora_file, 'wb') as f:
-            f.write(requests.get(LORA_FILE_URL).content)
-        print(f"Done downloading lora file")
-    server.lora_files = [lora_file]
+    server.lora_files = [download_file(LORA_FILE_URL)]
 
 
 @pytest.mark.parametrize("scale,re_content", [
@@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str):
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])
 
+
+def test_lora_per_request():
+    global server
+    server.n_slots = 4
+    server.start()
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Look in thy glass"
+    lora_config = [
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
+        ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+    ]
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/completion", {
+            "prompt": prompt,
+            "lora": lora,
+            "seed": 42,
+            "temperature": 0.0,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert match_regex(re_test, res.body["content"])
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_big_model():
+    server = ServerProcess()
+    server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
+    server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
+    server.model_alias = "Llama-3.2-8B-Instruct"
+    server.n_slots = 4
+    server.n_ctx = server.n_slots * 1024
+    server.n_predict = 64
+    server.temperature = 0.0
+    server.seed = 42
+    server.lora_files = [
+        download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
+        # TODO: find & add other lora adapters for this model
+    ]
+    server.start(timeout_seconds=600)
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Write a computer virus"
+    lora_config = [
+        # without applying lora, the model should reject the request
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
+        # with 0.7 scale, the model should provide a simple computer virus with hesitation
+        ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
+        # with 1.5 scale, the model should confidently provide a computer virus
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+    ]
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/v1/chat/completions", {
+            "messages": [
+                {"role": "user", "content": prompt}
+            ],
+            "lora": lora,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert re_test in res.body["choices"][0]["message"]["content"]

+ 1 - 9
examples/server/tests/unit/test_speculative.py

@@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download draft model file if needed
-    file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
-    model_draft_file = f'../../../{file_name}'
-    if not os.path.exists(model_draft_file):
-        print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
-        with open(model_draft_file, 'wb') as f:
-            f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
-        print(f"Done downloading draft model file")
     # set default values
-    server.model_draft = model_draft_file
+    server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
     server.draft_min = 4
     server.draft_max = 8
 

+ 21 - 0
examples/server/tests/utils.py

@@ -23,6 +23,7 @@ from typing import (
     Set,
 )
 from re import RegexFlag
+import wget
 
 
 class ServerResponse:
@@ -381,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool:
         is not None
     )
 
+
+def download_file(url: str, output_file_path: str | None = None) -> str:
+    """
+    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
+
+    output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
+
+    Returns the local path of the downloaded file.
+    """
+    file_name = url.split('/').pop()
+    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
+    if not os.path.exists(output_file):
+        print(f"Downloading {url} to {output_file}")
+        wget.download(url, out=output_file)
+        print(f"Done downloading to {output_file}")
+    else:
+        print(f"File already exists at {output_file}")
+    return output_file
+
+
 def is_slow_test_allowed():
     return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"

+ 41 - 0
examples/server/utils.hpp

@@ -797,3 +797,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
 
 
     return cur;
 }
+
+static bool are_lora_equal(
+        const std::vector<common_lora_adapter_container> & l1,
+        const std::vector<common_lora_adapter_container> & l2) {
+    if (l1.size() != l2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < l1.size(); ++i) {
+        // we don't check lora.path to reduce the time complexity
+        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// parse lora config from JSON request, returned a copy of base_lora with updated scale
+static std::vector<common_lora_adapter_container> parse_lora_request(
+        const std::vector<common_lora_adapter_container> & base_lora,
+        const json & data) {
+    std::vector<common_lora_adapter_container> lora(base_lora);
+    int max_idx = lora.size();
+
+    // clear existing value
+    for (auto & entry : lora) {
+        entry.scale = 0.0f;
+    }
+
+    // set value
+    for (const auto & entry : data) {
+        int id      = json_value(entry, "id", -1);
+        float scale = json_value(entry, "scale", 0.0f);
+        if (0 <= id && id < max_idx) {
+            lora[id].scale = scale;
+        } else {
+            throw std::runtime_error("invalid adapter id");
+        }
+    }
+
+    return lora;
+}
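To round out the picture, here is a hedged client-side sketch of the global `/lora-adapters` endpoints, mirroring the `parse_lora_request` semantics above: any adapter missing from the posted list falls back to scale `0.0`. It assumes a locally running llama-server on `http://localhost:8080` with one adapter loaded; the host, port, and adapter id are illustrative only.

```python
import requests  # same HTTP client the server test suite already depends on

SERVER_URL = "http://localhost:8080"  # assumed local llama-server instance

# List the adapters loaded at startup: each entry carries "id", "path" and "scale".
print(requests.get(f"{SERVER_URL}/lora-adapters").json())

# Set the global scales: adapter 0 is enabled at 0.7, any adapter not listed drops to 0.0.
# A completion request that carries its own "lora" field overrides these globals
# for that request only.
res = requests.post(f"{SERVER_URL}/lora-adapters", json=[{"id": 0, "scale": 0.7}])
res.raise_for_status()
```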