
server: tests: add truncated prompt tests, better kv cache size (#5933)

* server: tests: add truncated prompt tests, better size

* server, tests : update regex

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Pierrick Hymbert 1 year ago
commit fd72d2d2a5

+ 19 - 4
examples/server/server.cpp

@@ -1128,6 +1128,7 @@ struct server_context {

            LOG_VERBOSE("stopped by limit", {
                {"id_slot",   slot.id},
+                {"id_task",   slot.id_task},
                {"n_decoded", slot.n_decoded},
                {"n_predict", slot.params.n_predict},
            });
@@ -1141,6 +1142,8 @@ struct server_context {
        }

        LOG_VERBOSE("next token", {
+            {"id_slot",        slot.id},
+            {"id_task",        slot.id_task},
            {"token",          result.tok},
            {"token_text",     tokens_to_output_formatted_string(ctx, result.tok)},
            {"has_next_token", slot.has_next_token},
@@ -1750,6 +1753,15 @@ struct server_context {
                        slot.n_past = 0;
                        slot.n_prompt_tokens = prompt_tokens.size();

+                        LOG_VERBOSE("prompt tokenized", {
+                            {"id_slot",         slot.id},
+                            {"id_task",         slot.id_task},
+                            {"n_ctx",           slot.n_ctx},
+                            {"n_keep",          slot.params.n_keep},
+                            {"n_prompt_tokens", slot.n_prompt_tokens},
+                            {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                        });
+
                        if (slot.embedding) {
                            // this prompt is too large to process - discard it
                            if (slot.n_prompt_tokens > n_batch) {
@@ -1788,10 +1800,13 @@ struct server_context {
                                slot.n_prompt_tokens = prompt_tokens.size();

                                LOG_VERBOSE("input truncated", {
-                                    {"n_ctx",         slot.n_ctx},
-                                    {"n_keep",        slot.params.n_keep},
-                                    {"n_left",        n_left},
-                                    {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                                    {"id_slot",         slot.id},
+                                    {"id_task",         slot.id_task},
+                                    {"n_ctx",           slot.n_ctx},
+                                    {"n_keep",          slot.params.n_keep},
+                                    {"n_left",          n_left},
+                                    {"n_prompt_tokens", slot.n_prompt_tokens},
+                                    {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                                });

                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
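
The added `id_slot` and `id_task` fields let verbose log events from concurrent requests be correlated. Purely as an illustration (the key names come from the diff above; the surrounding JSON layout and the concrete values are assumptions), an "input truncated" event now carries enough context to tell which request was trimmed and to what size:

```python
# Hypothetical shape of an "input truncated" verbose log event after this
# change. Key names match the LOG_VERBOSE call above; layout and values
# are illustrative only.
import json

event = {
    "msg": "input truncated",
    "id_slot": 0,            # slot that is processing the request
    "id_task": 42,           # queued task the slot is serving
    "n_ctx": 128,            # per-slot context size (assumed value)
    "n_keep": 0,             # tokens always kept at the start of the prompt
    "n_left": 128,           # n_ctx - n_keep
    "n_prompt_tokens": 109,  # prompt size after truncation
}
print(json.dumps(event, indent=2))
```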

+ 3 - 2
examples/server/tests/features/parallel.feature

@@ -6,8 +6,8 @@ Feature: Parallel
    Given a server listening on localhost:8080
    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
    And   42 as server seed
-    And   512 as batch size
-    And   64 KV cache size
+    And   128 as batch size
+    And   256 KV cache size
    And   2 slots
    And   continuous batching
    Then  the server is starting
@@ -76,6 +76,7 @@ Feature: Parallel
      | disabled  | 128       |
      | enabled   | 64        |

+
  Scenario:  Multi users with total number of tokens to predict exceeds the KV Cache size #3969
    Given a prompt:
      """

+ 30 - 11
examples/server/tests/features/server.feature

@@ -10,11 +10,10 @@ Feature: llama.cpp server
      # KV Cache corresponds to the total amount of tokens
      # that can be stored across all independent sequences: #4130
      # see --ctx-size and #5568
-    And   32 KV cache size
-    And   512 as batch size
-    And   1 slots
-    And   embeddings extraction
-    And   32 server max tokens to predict
+    And   256 KV cache size
+    And   32 as batch size
+    And   2 slots
+    And   64 server max tokens to predict
    And   prometheus compatible metrics exposed
    Then  the server is starting
    Then  the server is healthy
@@ -23,18 +22,35 @@ Feature: llama.cpp server
    Then the server is ready
    And  all slots are idle

+
  Scenario Outline: Completion
    Given a prompt <prompt>
    And   <n_predict> max tokens to predict
    And   a completion request with no api error
    Then  <n_predicted> tokens are predicted matching <re_content>
+    And   the completion is <truncated> truncated
+    And   <n_prompt> prompt tokens are processed
    And   prometheus metrics are exposed
    And   metric llamacpp:tokens_predicted is <n_predicted>

    Examples: Prompts
-      | prompt                           | n_predict | re_content                       | n_predicted |
-      | I believe the meaning of life is | 8         | (read\|going)+                   | 8           |
-      | Write a joke about AI            | 64        | (park\|friends\|scared\|always)+ | 32          |
+      | prompt                                                                    | n_predict | re_content                    | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is                                          | 8         | (read\|going)+                | 18       | 8           | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+ | 46       | 64          | not       |
+
+  Scenario: Completion prompt truncated
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching fun|Annaks|popcorns
+    And   the completion is  truncated
+    And   109 prompt tokens are processed
+

  Scenario Outline: OAI Compatibility
    Given a model <model>
@@ -44,11 +60,14 @@ Feature: llama.cpp server
    And   streaming is <enable_streaming>
    Given an OAI compatible chat completions request with no api error
    Then  <n_predicted> tokens are predicted matching <re_content>
+    And   <n_prompt> prompt tokens are processed
+    And   the completion is <truncated> truncated

    Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_predicted | enable_streaming |
-      | llama-2      | Book                        | What is the best book                | 8          | (Mom\|what)+           | 8           | disabled         |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks\|happy\|bird)+ | 32          | enabled          |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+          | 77       | 8           | disabled         | not       |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+ | -1       | 64          | enabled          |           |
+

  Scenario: Tokenize / Detokenize
    When tokenizing:
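
The new truncation scenario reads two fields of the completion response that the test steps below rely on: the top-level `truncated` flag and `timings.prompt_n`. A minimal manual check against a server started like the Background above (localhost:8080, 256 KV cache, 2 slots) might look like this; the endpoint and field names follow the test steps, everything else is illustrative:

```python
# Sketch: reproduce the "Completion prompt truncated" check by hand against
# a running llama.cpp server (assumed to listen on localhost:8080).
import requests

long_prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 30

resp = requests.post(
    "http://localhost:8080/completion",
    json={"prompt": long_prompt, "n_predict": 64},
).json()

print("truncated:", resp["truncated"])                # expected True for an oversized prompt
print("prompt tokens:", resp["timings"]["prompt_n"])  # tokens actually kept after truncation
```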

+ 29 - 6
examples/server/tests/features/steps/steps.py

@@ -196,12 +196,30 @@ async def step_request_completion(context, api_error):

@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n, re_content)


@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n)
+
+
+@step(u'the completion is  truncated')
+def step_assert_completion_truncated(context):
+    step_assert_completion_truncated(context, '')
+
+
+@step(u'the completion is {truncated} truncated')
+def step_assert_completion_truncated(context, truncated):
+    truncated = truncated != "not"
+    assert context.completion['truncated'] == truncated, f'{context.completion}'
+
+
+@step(u'{n_prompt:d} prompt tokens are processed')
+def step_impl(context, n_prompt):
+    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"


@step(u'a user prompt {user_prompt}')
@@ -722,7 +740,8 @@ async def oai_chat_completions(user_prompt,
    completion_response = {
        'content': '',
        'timings': {
-            'predicted_n': 0
+            'predicted_n': 0,
+            'prompt_n': 0
        }
    }
    if async_client:
@@ -763,7 +782,8 @@ async def oai_chat_completions(user_prompt,
                        completion_response = {
                            'content': chat_completion_raw['choices'][0]['message'],
                            'timings': {
-                                'predicted_n': chat_completion_raw['usage']['completion_tokens']
+                                'predicted_n': chat_completion_raw['usage']['completion_tokens'],
+                                'prompt_n': chat_completion_raw['usage']['prompt_tokens']
                            }
                        }
                    else:
@@ -792,13 +812,16 @@ async def oai_chat_completions(user_prompt,
                if 'content' in delta:
                    completion_response['content'] += delta['content']
                    completion_response['timings']['predicted_n'] += 1
+                completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
        else:
            assert len(chat_completion.choices) == 1
            completion_response = {
                'content': chat_completion.choices[0].message.content,
                'timings': {
-                    'predicted_n': chat_completion.usage.completion_tokens
-                }
+                    'predicted_n': chat_completion.usage.completion_tokens,
+                    'prompt_n': chat_completion.usage.prompt_tokens
+                    },
+                'truncated': chat_completion.choices[0].finish_reason != 'stop'
            }
    if debug:
        print("OAI response formatted to llama.cpp:", completion_response)