server : add --no-context-shift option (#9607)

* server : add --no-context-shift option

* small fix

* Update examples/server/tests/features/embeddings.feature

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* tests : minor fix

* revert usage of GGML_ASSERT

* update server documentation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan Son Nguyen, 1 year ago
commit 0b3bf966f4

+ 1 - 1
common/arg.cpp

@@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),

+ 11 - 9
examples/server/README.md

@@ -21,8 +21,6 @@ The project is under active development, and we are [looking for feedback and co
 | -------- | ----------- |
 | `-h, --help, --usage` | print usage and exit |
 | `--version` | show version and build info |
-| `-v, --verbose` | print verbose information |
-| `--verbosity N` | set specific verbosity level (default: 0) |
 | `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
 | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
 | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -40,15 +38,18 @@ The project is under active development, and we are [looking for feedback and co
 | `-b, --batch-size N` | logical maximum batch size (default: 2048)<br/>(env: LLAMA_ARG_BATCH) |
 | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
+| `--no-context-shift` | disables context shift on infinite text generation (default: disabled) |
 | `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
 | `-p, --prompt PROMPT` | prompt to start generation with |
+| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
 | `-f, --file FNAME` | a file containing the prompt (default: none) |
 | `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) |
 | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
 | `--no-escape` | do not process escape sequences |
+| `-sp, --special` | special tokens output enabled (default: false) |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
 | `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: top_k;tfs_z;typ_p;top_p;min_p;temperature) |
-| `-s, --seed SEED` | RNG seed (default: -1, use random seed for < 0) |
+| `-s, --seed SEED` | RNG seed (default: 4294967295, use random seed for 4294967295) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
 | `--penalize-nl` | penalize newline tokens (default: false) |
@@ -87,7 +88,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16) |
 | `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16) |
 | `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
-| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env:  LLAMA_ARG_N_PARALLEL) |
+| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
 | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
 | `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
 | `--mlock` | force system to keep model in RAM rather than swapping or compressing |
@@ -128,12 +129,13 @@ The project is under active development, and we are [looking for feedback and co
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
 | `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) |
-| `--log-test` | Log test |
 | `--log-disable` | Log disable |
-| `--log-enable` | Log enable |
-| `--log-new` | Log new |
-| `--log-append` | Log append |
-| `--log-file FNAME` | Log file |
+| `--log-file FNAME` | Log to file |
+| `--log-colors` | Enable colored logging<br/>(env: LLAMA_LOG_COLORS) |
+| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
+| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.<br/>(env: LLAMA_LOG_VERBOSITY) |
+| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
+| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
 
 
 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.

+ 26 - 1
examples/server/server.cpp

@@ -1180,6 +1180,15 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }

+        // if context shift is disabled, we stop when it reaches the context limit
+        if (slot.n_decoded >= slot.n_ctx) {
+            slot.truncated      = true;
+            slot.stopped_limit  = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+        }
+
         if (llama_token_is_eog(model, result.tok)) {
             slot.stopped_eos    = true;
             slot.has_next_token = false;
@@ -1480,7 +1489,7 @@ struct server_context {
             if (result.error) {
                 error_handler(result.data);
                 cancel_tasks(id_tasks);
-                break;
+                return;
             }

             size_t idx = result.data["index"];
@@ -1827,6 +1836,14 @@ struct server_context {
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+                    if (!params.ctx_shift) {
+                        // this check is redundant (for good measure)
+                        // we should never get here, because generation should have already stopped in process_token()
+                        slot.release();
+                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                        continue;
+                    }
+
                     // Shift context
                     const int n_keep    = slot.params.n_keep + add_bos_token;
                     const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1961,6 +1978,14 @@ struct server_context {
                                 continue;
                             }
                         } else {
+                            if (!params.ctx_shift) {
+                                // if context shift is disabled, we make sure prompt size is smaller than KV size
+                                if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                                    slot.release();
+                                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                                    continue;
+                                }
+                            }
                             if (slot.params.n_keep < 0) {
                                 slot.params.n_keep = slot.n_prompt_tokens;
                             }
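Taken together, the two server.cpp checks above give `--no-context-shift` a clear client-facing contract: generation stops once the slot's context is full, and a prompt that cannot fit at all is rejected up front with an invalid-request error. Below is a minimal client-side sketch of that second path; it is not part of this commit. The `/completion` endpoint and its `prompt`/`n_predict` fields are the server's existing API, while the model path, port, and the use of `requests` are illustrative assumptions.

```python
# Sketch only: assumes a server started with something like
#   llama-server -m model.gguf -c 256 --no-context-shift --port 8080
# and that the prompt below tokenizes to more than 256 tokens.
import requests

resp = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "Lorem ipsum dolor sit amet, " * 200,  # deliberately longer than the context
        "n_predict": 64,
    },
)

if resp.status_code == 400:
    # Expected with --no-context-shift: the server refuses the request with
    # "the request exceeds the available context size. try increasing the
    # context size or enable context shift" instead of shifting the context.
    print("rejected:", resp.text)
else:
    print("completed with status", resp.status_code)
```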

+ 62 - 0
examples/server/tests/features/ctx_shift.feature

@@ -0,0 +1,62 @@
+@llama.cpp
+@ctx_shift
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And   a model file test-model.gguf
+    And   a model alias tinyllama-2
+    And   BOS token is 1
+    And   42 as server seed
+    And   256 KV cache size
+    And   32 as batch size
+    And   2 slots
+
+  Scenario: Inference with context shift
+    And   64 server max tokens to predict
+    Then  the server is starting
+    Then  the server is healthy
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
+    And   the completion is  truncated
+    And   109 prompt tokens are processed
+
+  Scenario Outline: Inference without context shift
+    And   <n_predict> server max tokens to predict
+    And   disable context shifting
+    Then  the server is starting
+    Then  the server is healthy
+    Given a prompt:
+    """
+    Hi how are you
+    """
+    And   a completion request with no api error
+    Then  <n_token_output> tokens are predicted matching twind|Anna
+    And   the completion is <truncated> truncated
+    And   8 prompt tokens are processed
+    Examples:
+      | n_predict | n_token_output | truncated |
+      | 64        | 64             | not       |
+      | -1        | 120            |           |
+
+  Scenario: Inference without context shift (expected error: prompt too long)
+    And   disable context shifting
+    Then  the server is starting
+    Then  the server is healthy
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with 400 api error
+
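A note on the numbers in the Scenario Outline above: the server splits the KV cache evenly across slots, so the 256-token cache and 2 slots give each slot a 128-token context. With an 8-token prompt and no context shift, at most 120 tokens can be generated before the limit is hit, which is why the `-1` row expects 120 tokens and a truncated completion. A small worked sketch (not part of the commit):

```python
# Worked arithmetic for the "-1" row of the Scenario Outline above.
kv_cache_size = 256   # "256 KV cache size" in the Background
n_slots       = 2     # "2 slots"
n_ctx_slot    = kv_cache_size // n_slots    # per-slot context the server assigns

prompt_tokens = 8     # "8 prompt tokens are processed"
max_generated = n_ctx_slot - prompt_tokens  # tokens that fit before hitting the limit

assert n_ctx_slot == 128
assert max_generated == 120  # matches the expected n_token_output; the completion is truncated
```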

+ 18 - 4
examples/server/tests/features/embeddings.feature

@@ -10,11 +10,11 @@ Feature: llama.cpp server
     And   42 as server seed
     And   2 slots
     # the bert-bge-small model has context size of 512
-    # since the generated prompts are as big as the batch size, we need to set the batch size to 512
+    # since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
     # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
-    And   512 as batch size
-    And   512 as ubatch size
-    And   2048 KV cache size
+    And   128 as batch size
+    And   128 as ubatch size
+    And   512 KV cache size
     And   embeddings extraction
     Then  the server is starting
     Then  the server is healthy
@@ -26,6 +26,20 @@ Feature: llama.cpp server
     """
     """
     Then embeddings are generated
     Then embeddings are generated
 
 
+  Scenario: Embedding (error: prompt too long)
+    When embeddings are computed for:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And  embeddings request with 500 api error
+
   Scenario: OAI Embeddings compatibility
     Given a model bert-bge-small
     When an OAI compatible embeddings computation request for:
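The new scenario above relies on the `/embedding` endpoint returning a 500 status when the input does not fit into the 512-token KV cache. A standalone sketch of that request path follows; it mirrors the `request_embedding()` helper further down but is not part of the commit. It assumes a llama.cpp server running the bert-bge-small model with embeddings enabled and a 512-token KV cache, as in the Background above; the URL and content length are illustrative.

```python
# Standalone sketch of the "prompt too long" embedding path (not part of the commit).
import asyncio
import aiohttp

async def main() -> None:
    # Deliberately longer than the 512-token context of bert-bge-small.
    content = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 100
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:8080/embedding",
                                json={"content": content}) as response:
            if response.status == 200:
                data = await response.json()
                print("embedding length:", len(data["embedding"]))
            else:
                # The new feature scenario expects a 500 here.
                print("error status:", response.status)

asyncio.run(main())
```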

+ 21 - 7
examples/server/tests/features/steps/steps.py

@@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.response_format = None
     context.temperature = None
     context.lora_file = None
+    context.disable_ctx_shift = False
 
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
 
 
 @step('{n_predict:d} server max tokens to predict')
 def step_server_n_predict(context, n_predict: int):
-    context.n_server_predict = n_predict
+    context.n_server_predict = n_predict if n_predict > 0 else None
 
 
 
 
 @step('{slot_save_path} as slot save path')
 @step('{slot_save_path} as slot save path')
@@ -180,6 +181,9 @@ def step_server_embeddings(context):
 def step_server_metrics(context):
     context.server_metrics = True

+@step('disable context shifting')
+def step_server_disable_ctx_shift(context):
+    context.disable_ctx_shift = True
 
 
 @step("the server is starting")
 def step_start_server(context):
@@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
 @step('a completion request with {api_error} api error')
 @async_run_until_complete
 async def step_request_completion(context, api_error: Literal['raised'] | str):
-    expect_api_error = api_error == 'raised'
+    expect_api_error = api_error == 'raised' or api_error != 'no'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await request_completion(context.prompts.pop(),
                                           seeds[0] if seeds is not None else seeds,
@@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
     context.tasks_result.append(completion)
     if context.debug:
         print(f"Completion response: {completion}")
-    if expect_api_error:
+    if api_error == 'raised':
         assert completion == 401, f"completion must be an 401 status code: {completion}"
+    elif api_error.isdigit():
+        api_error_code = int(api_error)
+        assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
 
 
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
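For clarity, the reworked `api_error` handling in `step_request_completion()` now distinguishes three phrasings of the step. The small self-contained sketch below only spells out the mapping implemented in the hunk above; the `classify()` helper is hypothetical and not part of the test suite.

```python
# Hypothetical helper mirroring the branches added to step_request_completion():
# 'no' -> success expected, 'raised' -> 401 expected, a digit string -> that code expected.
def classify(api_error: str) -> str:
    expect_api_error = api_error == 'raised' or api_error != 'no'
    if api_error == 'raised':
        return "expect error, assert status 401"
    if api_error.isdigit():
        return f"expect error, assert status {int(api_error)}"
    return "expect success" if not expect_api_error else "expect error (no status asserted)"

for phrase in ("no", "raised", "400", "500"):
    print(f"a completion request with {phrase} api error -> {classify(phrase)}")
```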
@@ -645,6 +652,9 @@ def step_assert_embeddings(context):
     for embedding in context.embeddings:
         assert_embeddings(embedding)

+@step('embeddings request with {api_error_code:d} api error')
+def step_assert_embeddings(context, api_error_code: int):
+    assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
 
 
 @step('an OAI compatible embeddings computation request for')
 @async_run_until_complete
@@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
     return completion_response


-async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
+async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
     async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
                                     "content": content,
                                 }) as response:
-            assert response.status == 200
-            response_json = await response.json()
-            return [response_json['embedding']]
+            if response.status == 200:
+                response_json = await response.json()
+                return [response_json['embedding']]
+            else:
+                return response.status
 
 
 
 
 async def request_oai_embeddings(input, seed,
@@ -1372,6 +1384,8 @@ def start_server_background(context):
         server_args.append('--verbose')
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
+    if context.disable_ctx_shift:
+        server_args.extend(['--no-context-shift'])
 
 
     args = [str(arg) for arg in [context.server_path, *server_args]]
     print(f"bench: starting server with: {' '.join(args)}")