Просмотр исходного кода

server : fix crash when prompt exceeds context size (#3996)

Alexey Parfenov 2 лет назад
Родитель
Сommit
d96ca7ded7
1 измененных файлов с 29 добавлено и 29 удалено
  1. 29 29
      examples/server/server.cpp

+ 29 - 29
examples/server/server.cpp

@@ -1557,6 +1557,35 @@ struct llama_server_context
 
                     slot.num_prompt_tokens = prompt_tokens.size();
 
+                    if (slot.params.n_keep < 0)
+                    {
+                        slot.params.n_keep = slot.num_prompt_tokens;
+                    }
+                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+                    // if input prompt is too big, truncate it
+                    if (slot.num_prompt_tokens >= slot.n_ctx)
+                    {
+                        const int n_left = slot.n_ctx - slot.params.n_keep;
+                        const int n_block_size = n_left / 2;
+                        const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+                        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
+                        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+                        LOG_VERBOSE("input truncated", {
+                            {"n_ctx",  slot.n_ctx},
+                            {"n_keep", slot.params.n_keep},
+                            {"n_left", n_left},
+                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                        });
+                        slot.truncated = true;
+                        prompt_tokens = new_tokens;
+
+                        slot.num_prompt_tokens = prompt_tokens.size();
+                        GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+                    }
+
                     if (!slot.params.cache_prompt)
                     {
                         llama_sampling_reset(slot.ctx_sampling);
@@ -1566,35 +1595,6 @@ struct llama_server_context
                     }
                     else
                     {
-                        if (slot.params.n_keep < 0)
-                        {
-                            slot.params.n_keep = slot.num_prompt_tokens;
-                        }
-                        slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                        // if input prompt is too big, truncate it
-                        if (slot.num_prompt_tokens >= slot.n_ctx)
-                        {
-                            const int n_left = slot.n_ctx - slot.params.n_keep;
-                            const int n_block_size = n_left / 2;
-                            const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-                            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
-                            LOG_VERBOSE("input truncated", {
-                                                            {"n_ctx",  slot.n_ctx},
-                                                            {"n_keep", slot.params.n_keep},
-                                                            {"n_left", n_left},
-                                                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-                                                        });
-                            slot.truncated = true;
-                            prompt_tokens = new_tokens;
-
-                            slot.num_prompt_tokens = prompt_tokens.size();
-                            GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
-                        }
-
                         // push the prompt into the sampling context (do not apply grammar)
                         for (auto &token : prompt_tokens)
                         {