Преглед изворни кода

server : handle context overflow during decode (#17267)

* server : handle context overflow during decode

* server : minor refactor
Georgi Gerganov пре 2 месеци
родитељ
комит
5b2093becc
1 измењених фајлова са 30 додато и 29 уклоњено
  1. 30 29
      tools/server/server.cpp

+ 30 - 29
tools/server/server.cpp

@@ -1686,14 +1686,13 @@ struct server_slot {
         llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
     }
 
-    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+    bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
-
-            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
-            prompt.tokens.clear();
         }
+
+        return res;
     }
 
     std::vector<common_adapter_lora_info> lora;
@@ -2339,7 +2338,6 @@ struct server_context {
 
     llama_batch batch {};
 
-    bool clean_kv_cache = true;
     bool add_bos_token  = true;
 
     int32_t n_ctx; // total context for all clients / slots
@@ -2702,7 +2700,10 @@ struct server_context {
                 const int64_t t_start = ggml_time_us();
 
                 ret->prompt_save(*prompt_cache);
-                ret->prompt_load(*prompt_cache, task.tokens);
+
+                if (!ret->prompt_load(*prompt_cache, task.tokens)) {
+                    clear_slot(*ret);
+                }
 
                 prompt_cache->update();
 
@@ -2713,12 +2714,21 @@ struct server_context {
         return ret;
     }
 
-    // return true if at least one slot has been purged
+    void clear_slot(server_slot & slot) const {
+        GGML_ASSERT(!slot.is_processing());
+
+        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+        slot.prompt.tokens.clear();
+    }
+
+    // return true if at least one slot has been cleared
     // TODO: improve logic
-    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - smarter decision which slot to clear (LRU or longest prompt?)
     //       - move slot to level 2 cache instead of removing?
     //       - instead of purging, try to store and resume later?
-    bool try_purge_idle_slots() {
+    bool try_clear_idle_slots() {
         bool res = false;
 
         if (!params_base.kv_unified) {
@@ -2733,12 +2743,11 @@ struct server_context {
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
 
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-                slot.prompt.tokens.clear();
+                clear_slot(slot);
 
                 res = true;
 
-                // purge slots one by one
+                // clear slots one by one
                 break;
             }
         }
@@ -2848,14 +2857,6 @@ struct server_context {
         return true;
     }
 
-    void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
-        llama_memory_clear(llama_get_memory(ctx), true);
-        clean_kv_cache = false;
-    }
-
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;
@@ -3443,8 +3444,8 @@ struct server_context {
 
                     // Erase token cache
                     const size_t n_erased = slot->prompt.tokens.size();
-                    llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
-                    slot->prompt.tokens.clear();
+
+                    clear_slot(*slot);
 
                     auto res = std::make_unique<server_task_result_slot_erase>();
                     res->id       = task.id;
@@ -3477,9 +3478,6 @@ struct server_context {
 
             if (all_idle) {
                 SRV_INF("%s", "all slots are idle\n");
-                if (clean_kv_cache) {
-                    kv_cache_clear();
-                }
 
                 return;
             }
@@ -3873,12 +3871,11 @@ struct server_context {
 
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                         SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
-                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+
+                        clear_slot(slot);
 
                         // there is no common part left
                         slot.n_prompt_tokens_cache = 0;
-
-                        slot.prompt.tokens.clear();
                     }
 
                     // check if we should process the image
@@ -4108,6 +4105,10 @@ struct server_context {
                             if (slot.is_processing()) {
                                 send_error(slot, err);
                                 slot.release();
+
+                                // note: it's complicated to keep track of how much of the current batch has been
+                                //       processed before the error occurred, so we simply clear the entire context
+                                clear_slot(slot);
                             }
                         }
 
@@ -4116,7 +4117,7 @@ struct server_context {
                 }
 
                 // retry with half the batch size to try to find a free slot in the KV cache
-                if (!try_purge_idle_slots()) {
+                if (!try_clear_idle_slots()) {
                     n_batch /= 2;
                 }