
server : make cache_reuse configurable per request (#17858)

Georgi Gerganov 1 month ago
parent commit 2bc96931d2

+ 2 - 0
tools/server/README.md

@@ -495,6 +495,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re
 
 `n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries.
 
+`n_cache_reuse`: Minimum chunk size to attempt reusing from the cache via KV shifting. For more info, see the `--cache-reuse` argument. Default: `0` (disabled).
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.
 
 `stop`: Specify a JSON array of stopping strings.
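The new field slots into the `/completion` request body alongside the other per-request parameters. A minimal usage sketch — the endpoint, `prompt`, and `n_predict` are existing server API, while the host/port and the concrete values are illustrative assumptions:

```shell
# Ask the server to attempt KV-cache chunk reuse for this request only,
# for matching chunks of at least 256 tokens (0 disables reuse).
curl http://localhost:8080/completion \
    -H "Content-Type: application/json" \
    -d '{
        "prompt": "Write a haiku about caching.",
        "n_predict": 64,
        "n_cache_reuse": 256
    }'
```

Per the server-context change below, the value is ignored (with a warning) when the KV cache cannot be shifted or the cached prompt contains multimodal tokens.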

+ 13 - 4
tools/server/server-context.cpp

@@ -1880,8 +1880,18 @@ struct server_context_impl {
                                     n_past = std::min(n_past, slot.alora_invocation_start - 1);
                                 }
 
+                                const auto n_cache_reuse = slot.task->params.n_cache_reuse;
+
+                                const bool can_cache_reuse =
+                                    llama_memory_can_shift(llama_get_memory(ctx)) &&
+                                    !slot.prompt.tokens.has_mtmd;
+
+                                if (!can_cache_reuse && n_cache_reuse > 0) {
+                                    SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
+                                }
+
                                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                                if (params_base.n_cache_reuse > 0) {
+                                if (can_cache_reuse && n_cache_reuse > 0) {
                                     GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
 
                                     size_t head_c = n_past; // cache
@@ -1892,7 +1902,7 @@ struct server_context_impl {
                                         GGML_ABORT("not supported by multimodal");
                                     }
 
-                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
+                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
 
                                     while (head_c < slot.prompt.tokens.size() &&
                                            head_p < input_tokens.size()) {
@@ -1901,11 +1911,10 @@ struct server_context_impl {
                                         while (head_c + n_match < slot.prompt.tokens.size() &&
                                                head_p + n_match < input_tokens.size()       &&
                                                slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
-
                                             n_match++;
                                         }
 
-                                        if (n_match >= (size_t) params_base.n_cache_reuse) {
+                                        if (n_match >= (size_t) n_cache_reuse) {
                                             SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                             //for (size_t i = head_p; i < head_p + n_match; i++) {
                                             //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());

+ 7 - 5
tools/server/server-task.cpp

@@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl(
 
     // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
     task_params defaults;
-    defaults.sampling    = params_base.sampling;
-    defaults.speculative = params_base.speculative;
-    defaults.n_keep      = params_base.n_keep;
-    defaults.n_predict   = params_base.n_predict;
-    defaults.antiprompt  = params_base.antiprompt;
+    defaults.sampling      = params_base.sampling;
+    defaults.speculative   = params_base.speculative;
+    defaults.n_keep        = params_base.n_keep;
+    defaults.n_predict     = params_base.n_predict;
+    defaults.n_cache_reuse = params_base.n_cache_reuse;
+    defaults.antiprompt    = params_base.antiprompt;
 
     // enabling this will output extra debug information in the HTTP responses from the server
     params.verbose           = params_base.verbosity > 9;
@@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl(
     params.n_keep           = json_value(data,       "n_keep",             defaults.n_keep);
     params.n_discard        = json_value(data,       "n_discard",          defaults.n_discard);
     params.n_cmpl           = json_value(data,       "n_cmpl",             json_value(data, "n", 1));
+    params.n_cache_reuse    = json_value(data,       "n_cache_reuse",      defaults.n_cache_reuse);
     //params.t_max_prompt_ms  = json_value(data,       "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data,       "t_max_predict_ms",   defaults.t_max_predict_ms);
     params.response_fields  = json_value(data,       "response_fields",    std::vector<std::string>());

+ 9 - 6
tools/server/server-task.h

@@ -55,6 +55,8 @@ struct task_params {
     int32_t n_indent  =  0; // minimum line indentation for the generated text in number of whitespace characters
     int32_t n_cmpl    =  1; // number of completions to generate from this prompt
 
+    int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
+
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
@@ -62,18 +64,19 @@ struct task_params {
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
-    bool timings_per_token = false;
+
+    bool timings_per_token   = false;
     bool post_sampling_probs = false;
 
     struct common_params_sampling sampling;
     struct common_params_speculative speculative;
 
     // response formatting
-    bool                         verbose                   = false;
-    task_response_type           res_type                  = TASK_RESPONSE_TYPE_NONE;
-    std::string                  oaicompat_model;
-    std::string                  oaicompat_cmpl_id;
-    common_chat_syntax           oaicompat_chat_syntax;
+    bool               verbose  = false;
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
+    std::string        oaicompat_model;
+    std::string        oaicompat_cmpl_id;
+    common_chat_syntax oaicompat_chat_syntax;
 
     // Embeddings
     int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)