Browse Source

server: init functional tests (#5566)

* server: tests: init scenarios
 - health and slots endpoints
 - completion endpoint
 - OAI compatible chat completion requests w/ and without streaming
 - completion multi users scenario
 - multi users scenario on OAI compatible endpoint with streaming
 - multi users with total number of tokens to predict exceeds the KV Cache size
 - server wrong usage scenario, like in Infinite loop of "context shift" #3969
 - slots shifting
 - continuous batching
 - embeddings endpoint
 - multi users embedding endpoint: Segmentation fault #5655
 - OpenAI-compatible embeddings API
 - tokenize endpoint
 - CORS and api key scenario

* server: CI GitHub workflow


---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Pierrick Hymbert 1 year ago
parent
commit
525213d2f5

+ 2 - 0
.github/ISSUE_TEMPLATE/bug.md

@@ -7,3 +7,5 @@ assignees: ''
 ---
 ---
 
 
 Please include information about your system, the steps to reproduce the bug, and the version of llama.cpp that you are using. If possible, please provide a minimal code example that reproduces the bug.
 Please include information about your system, the steps to reproduce the bug, and the version of llama.cpp that you are using. If possible, please provide a minimal code example that reproduces the bug.
+
+If the bug concerns the server, please try to reproduce it first using the [server test scenario framework](https://github.com/ggerganov/llama.cpp/tree/master/examples/server/tests).

+ 127 - 0
.github/workflows/server.yml

@@ -0,0 +1,127 @@
+# Server build and tests
+name: Server
+
+on:
+  workflow_dispatch: # allows manual triggering
+  push:
+    branches:
+      - master
+      - test/server-add-ci-test # FIXME remove
+    paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
+  pull_request:
+    types: [opened, synchronize, reopened]
+    paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
+
+jobs:
+  server:
+    runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+        build: [noavx, avx2, avx, avx512, cublas, clblast, openblas, kompute, vulkan]
+        sanitizer: [ADDRESS, THREAD, UNDEFINED]
+        build_type: [Debug, Release]
+        include:
+          - build: 'noavx'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF'
+            image: ubuntu:latest
+          - build: 'avx2'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON'
+            image: ubuntu:latest
+          - build: 'avx'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX2=OFF'
+            image: ubuntu:latest
+          - build: 'avx512'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX512=ON'
+            image: ubuntu:latest
+            experimental: true
+          - build: 'cublas'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUBLAS=ON'
+            image: nvidia/cuda:12.3.1-devel-ubuntu22.04
+            arch_not_available: true # require nvidia docker engine
+          - build: 'clblast'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CLBLAST=ON'
+            image: ubuntu:latest
+            arch_not_available: true
+          - build: 'openblas'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS'
+            image: ubuntu:latest
+          - build: 'kompute'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON'
+            image: ubuntu:latest
+            arch_not_available: true
+          - build: 'vulkan'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_VULKAN=ON'
+            image: ubuntu:latest
+            arch_not_available: true
+
+    container:
+      image: ${{ matrix.image }}
+      ports:
+        - 8888
+      options: --cpus 4
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v3
+
+      - name: Dependencies
+        id: depends
+        run: |
+          apt-get update
+          apt-get -y install \
+            build-essential \
+            pkg-config \
+            git \
+            cmake \
+            python3-pip \
+            wget \
+            psmisc
+
+      - name: Download CLBlast
+        id: get_clblast
+        if: ${{ matrix.build == 'clblast' }}
+        run: |
+          apt install -y libclblast-dev
+
+      - name: Download OpenBLAS
+        id: get_openblas
+        if: ${{ matrix.build == 'openblas' }}
+        run: |
+          apt-get -y install libopenblas-dev
+
+      - name: Install Vulkan SDK
+        id: get_vulkan
+        if: ${{ matrix.build == 'kompute' || matrix.build == 'vulkan' }}
+        run: |
+          wget -qO- https://packages.lunarg.com/lunarg-signing-key-pub.asc | tee /etc/apt/trusted.gpg.d/lunarg.asc
+          wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list http://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
+          apt-get update
+          apt-get -y install vulkan-sdk
+
+      - name: Build
+        id: cmake_build
+        run: |
+          mkdir build
+          cd build
+          cmake .. -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ${{ matrix.defines }}
+          cmake --build . --config ${{ matrix.build_type }} -j $(nproc) --target server
+
+      - name: Tests dependencies
+        id: test_dependencies
+        run: |
+          pip install -r examples/server/tests/requirements.txt
+
+      - name: Download models
+        id: download_models
+        run: |
+          cd examples/server/tests
+          ../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf
+
+      - name: Tests
+        id: server_integration_test
+        continue-on-error: ${{ matrix.experimental || matrix.arch_not_available }}
+        run: |
+          cd examples/server/tests
+          PORT=8888 ./tests.sh

+ 6 - 0
examples/server/README.md

@@ -98,6 +98,12 @@ curl --request POST \
     --data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
     --data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
 ```
 ```
 
 
+## Advanced testing
+
+We implemented a [server test framework](./tests/README.md) using human-readable scenario.
+
+*Before submitting an issue, please try to reproduce it with this format.*
+
 ## Node JS Test
 ## Node JS Test
 
 
 You need to have [Node.js](https://nodejs.org/en) installed.
 You need to have [Node.js](https://nodejs.org/en) installed.

+ 18 - 18
examples/server/server.cpp

@@ -1410,11 +1410,6 @@ struct llama_server_context
                 int n_processing_slots = 0;
                 int n_processing_slots = 0;
 
 
                 for (llama_client_slot &slot: slots) {
                 for (llama_client_slot &slot: slots) {
-                    if (slot.available()) {
-                        n_idle_slots++;
-                    } else {
-                        n_processing_slots++;
-                    }
                     json slot_data = get_formated_generation(slot);
                     json slot_data = get_formated_generation(slot);
                     slot_data["id"] = slot.id;
                     slot_data["id"] = slot.id;
                     slot_data["task_id"] = slot.task_id;
                     slot_data["task_id"] = slot.task_id;
@@ -1429,6 +1424,11 @@ struct llama_server_context
                             {"stopped_limit", slot.stopped_limit},
                             {"stopped_limit", slot.stopped_limit},
                             {"stopping_word", slot.stopping_word},
                             {"stopping_word", slot.stopping_word},
                     };
                     };
+                    if (slot_data["state"] == IDLE) {
+                        n_idle_slots++;
+                    } else {
+                        n_processing_slots++;
+                    }
                     slots_data.push_back(slot_data);
                     slots_data.push_back(slot_data);
                 }
                 }
                 LOG_TEE("task %i - slots data: idle=%i processing=%i\n", task.id, n_idle_slots, n_processing_slots);
                 LOG_TEE("task %i - slots data: idle=%i processing=%i\n", task.id, n_idle_slots, n_processing_slots);
@@ -2748,19 +2748,6 @@ int main(int argc, char **argv)
         log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
         log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
     }
     }
 
 
-    LOG_INFO("HTTP server listening", log_data);
-    // run the HTTP server in a thread - see comment below
-    std::thread t([&]()
-            {
-                if (!svr.listen_after_bind())
-                {
-                    state.store(SERVER_STATE_ERROR);
-                    return 1;
-                }
-
-                return 0;
-            });
-
     // load the model
     // load the model
     if (!llama.load_model(params))
     if (!llama.load_model(params))
     {
     {
@@ -3228,6 +3215,19 @@ int main(int argc, char **argv)
     }*/
     }*/
     //);
     //);
 
 
+    LOG_INFO("HTTP server listening", log_data);
+    // run the HTTP server in a thread - see comment below
+    std::thread t([&]()
+            {
+                if (!svr.listen_after_bind())
+                {
+                    state.store(SERVER_STATE_ERROR);
+                    return 1;
+                }
+
+                return 0;
+            });
+
     llama.queue_tasks.on_new_task(std::bind(
     llama.queue_tasks.on_new_task(std::bind(
         &llama_server_context::process_single_task, &llama, std::placeholders::_1));
         &llama_server_context::process_single_task, &llama, std::placeholders::_1));
     llama.queue_tasks.on_finish_multitask(std::bind(
     llama.queue_tasks.on_finish_multitask(std::bind(

+ 46 - 0
examples/server/tests/README.md

@@ -0,0 +1,46 @@
+# Server tests
+
+Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/):
+ * [issues.feature](./features/issues.feature) Pending issues scenario
+ * [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
+ * [security.feature](./features/security.feature) Security, CORS and API Key
+ * [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
+
+Tests target GitHub workflows job runners with 4 vCPU.
+
+Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client.
+
+Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`.
+
+### Install dependencies
+`pip install -r requirements.txt`
+
+### Run tests
+1. Build the server
+```shell
+cd ../../..
+mkdir build
+cd build
+cmake ../
+cmake --build . --target server
+```
+2. download required models:
+   1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf`
+3. Start the test: `./tests.sh`
+
+It's possible to override some scenario steps values with environment variables:
+ - `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
+ - `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
+ - `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose`
+
+### Run @bug, @wip or @wrong_usage annotated scenario
+
+Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
+- `@bug` annotation aims to link a scenario with a GitHub issue.
+- `@wrong_usage` are meant to show user issue that are actually an expected behavior
+- `@wip` to focus on a scenario working in progress
+
+To run a scenario annotated with `@bug`, start:
+`DEBUG=ON ./tests.sh --no-skipped --tags bug`
+
+After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated.

+ 67 - 0
examples/server/tests/features/environment.py

@@ -0,0 +1,67 @@
+import os
+import socket
+import subprocess
+import time
+from contextlib import closing
+from signal import SIGKILL
+
+
+def before_scenario(context, scenario):
+    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
+    port = 8080
+    if 'PORT' in os.environ:
+        port = int(os.environ['PORT'])
+    if is_server_listening("localhost", port):
+        assert False, "Server already started"
+
+
+def after_scenario(context, scenario):
+    if scenario.status == "failed":
+        if 'GITHUB_ACTIONS' in os.environ:
+            print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+            if os.path.isfile('llama.log'):
+                with closing(open('llama.log', 'r')) as f:
+                    for line in f:
+                        print(line)
+        if not is_server_listening(context.server_fqdn, context.server_port):
+            print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
+
+    if not pid_exists(context.server_process.pid):
+        assert False, f"Server not running pid={context.server_process.pid} ..."
+
+    print(f"stopping server pid={context.server_process.pid} ...")
+    context.server_process.kill()
+    # Wait few for socket to free up
+    time.sleep(0.05)
+
+    attempts = 0
+    while is_server_listening(context.server_fqdn, context.server_port):
+        print(f"stopping server pid={context.server_process.pid} ...")
+        os.kill(context.server_process.pid, SIGKILL)
+        time.sleep(0.1)
+        attempts += 1
+        if attempts > 5:
+            print(f"Server dangling exits, killing all {context.server_path} ...")
+            process = subprocess.run(['killall', '-9', context.server_path],
+                                     stderr=subprocess.PIPE,
+                                     universal_newlines=True)
+            print(process)
+
+
+def is_server_listening(server_fqdn, server_port):
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        result = sock.connect_ex((server_fqdn, server_port))
+        return result == 0
+
+
+def pid_exists(pid):
+    """Check whether pid exists in the current process table."""
+    import errno
+    if pid < 0:
+        return False
+    try:
+        os.kill(pid, 0)
+    except OSError as e:
+        return e.errno == errno.EPERM
+    else:
+        return True

+ 36 - 0
examples/server/tests/features/issues.feature

@@ -0,0 +1,36 @@
+# List of ongoing issues
+@bug
+Feature: Issues
+    # Issue #5655
+  Scenario: Multi users embeddings
+    Given a server listening on localhost:8080
+    And   a model file stories260K.gguf
+    And   a model alias tinyllama-2
+    And   42 as server seed
+    And   64 KV cache size
+    And   2 slots
+    And   continuous batching
+    And   embeddings extraction
+    Then  the server is starting
+    Then  the server is healthy
+
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    Given concurrent embedding requests
+    Then the server is busy
+    Then the server is idle
+    Then all embeddings are generated

+ 77 - 0
examples/server/tests/features/parallel.feature

@@ -0,0 +1,77 @@
+@llama.cpp
+Feature: Parallel
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file stories260K.gguf
+    And   a model alias tinyllama-2
+    And   42 as server seed
+    And   64 KV cache size
+    And   2 slots
+    And   continuous batching
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario Outline: Multi users completion
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And <n_predict> max tokens to predict
+    Given concurrent completion requests
+    Then the server is busy
+    Then the server is idle
+    And  all slots are idle
+    Then all prompts are predicted with <n_predict> tokens
+    Examples:
+      | n_predict |
+      | 128       |
+
+  Scenario Outline: Multi users OAI completions compatibility
+    Given a system prompt You are a writer.
+    And   a model tinyllama-2
+    Given a prompt:
+      """
+      Write a very long book.
+      """
+    And a prompt:
+      """
+      Write another a poem.
+      """
+    And <n_predict> max tokens to predict
+    And streaming is <streaming>
+    Given concurrent OAI completions requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted with <n_predict> tokens
+    Examples:
+      | streaming | n_predict |
+      | disabled  | 128       |
+      | enabled   | 64        |
+
+  Scenario:  Multi users with total number of tokens to predict exceeds the KV Cache size #3969
+    Given a prompt:
+      """
+      Write a very long story about AI.
+      """
+    And a prompt:
+      """
+      Write another very long music lyrics.
+      """
+    And a prompt:
+      """
+      Write a very long poem.
+      """
+    And a prompt:
+      """
+      Write a very long joke.
+      """
+    And 128 max tokens to predict
+    Given concurrent completion requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted

+ 50 - 0
examples/server/tests/features/security.feature

@@ -0,0 +1,50 @@
+@llama.cpp
+Feature: Security
+
+  Background: Server startup with an api key defined
+    Given a server listening on localhost:8080
+    And   a model file stories260K.gguf
+    And   a server api key llama.cpp
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario Outline: Completion with some user api key
+    Given a prompt test
+    And   a user api key <api_key>
+    And   4 max tokens to predict
+    And   a completion request with <api_error> api error
+
+    Examples: Prompts
+      | api_key   | api_error |
+      | llama.cpp | no        |
+      | llama.cpp | no        |
+      | hackeme   | raised    |
+      |           | raised    |
+
+  Scenario Outline: OAI Compatibility
+    Given a system prompt test
+    And   a user prompt test
+    And   a model test
+    And   2 max tokens to predict
+    And   streaming is disabled
+    And   a user api key <api_key>
+    Given an OAI compatible chat completions request with <api_error> api error
+
+    Examples: Prompts
+      | api_key   | api_error |
+      | llama.cpp | no        |
+      | llama.cpp | no        |
+      | hackme    | raised    |
+
+
+  Scenario Outline: CORS Options
+    When an OPTIONS request is sent from <origin>
+    Then CORS header <cors_header> is set to <cors_header_value>
+
+    Examples: Headers
+      | origin          | cors_header                      | cors_header_value |
+      | localhost       | Access-Control-Allow-Origin      | localhost         |
+      | web.mydomain.fr | Access-Control-Allow-Origin      | web.mydomain.fr   |
+      | origin          | Access-Control-Allow-Credentials | true              |
+      | web.mydomain.fr | Access-Control-Allow-Methods     | POST              |
+      | web.mydomain.fr | Access-Control-Allow-Headers     | *                 |

+ 69 - 0
examples/server/tests/features/server.feature

@@ -0,0 +1,69 @@
+@llama.cpp
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file stories260K.gguf
+    And   a model alias tinyllama-2
+    And   42 as server seed
+      # KV Cache corresponds to the total amount of tokens
+      # that can be stored across all independent sequences: #4130
+      # see --ctx-size and #5568
+    And   32 KV cache size
+    And   1 slots
+    And   embeddings extraction
+    And   32 server max tokens to predict
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Health
+    Then the server is ready
+    And  all slots are idle
+
+  Scenario Outline: Completion
+    Given a prompt <prompt>
+    And   <n_predict> max tokens to predict
+    And   a completion request with no api error
+    Then  <n_predicted> tokens are predicted matching <re_content>
+
+    Examples: Prompts
+      | prompt                           | n_predict | re_content                   | n_predicted |
+      | I believe the meaning of life is | 8         | read                         | 8           |
+      | Write a joke about AI            | 64        | (park<or>friends<or>scared)+ | 32          |
+
+  Scenario Outline: OAI Compatibility
+    Given a model <model>
+    And   a system prompt <system_prompt>
+    And   a user prompt <user_prompt>
+    And   <max_tokens> max tokens to predict
+    And   streaming is <enable_streaming>
+    Given an OAI compatible chat completions request with no api error
+    Then  <n_predicted> tokens are predicted matching <re_content>
+
+    Examples: Prompts
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content                 | n_predicted | enable_streaming |
+      | llama-2      | Book                        | What is the best book                | 8          | (Mom<or>what)+             | 8           | disabled         |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks<or>happy<or>bird)+ | 32          | enabled          |
+
+  Scenario: Embedding
+    When embeddings are computed for:
+    """
+    What is the capital of Bulgaria ?
+    """
+    Then embeddings are generated
+
+  Scenario: OAI Embeddings compatibility
+    Given a model tinyllama-2
+    When an OAI compatible embeddings computation request for:
+    """
+    What is the capital of Spain ?
+    """
+    Then embeddings are generated
+
+
+  Scenario: Tokenize / Detokenize
+    When tokenizing:
+    """
+    What is the capital of France ?
+    """
+    Then tokens can be detokenize

+ 709 - 0
examples/server/tests/features/steps/steps.py

@@ -0,0 +1,709 @@
+import asyncio
+import json
+import os
+import re
+import socket
+import subprocess
+import time
+from contextlib import closing
+from re import RegexFlag
+
+import aiohttp
+import openai
+from behave import step
+from behave.api.async_step import async_run_until_complete
+
+
+@step(u"a server listening on {server_fqdn}:{server_port}")
+def step_server_config(context, server_fqdn, server_port):
+    context.server_fqdn = server_fqdn
+    context.server_port = int(server_port)
+    if 'PORT' in os.environ:
+        context.server_port = int(os.environ['PORT'])
+        print(f"$PORT set, overriding server port with to {context.server_port}")
+
+    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
+
+    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
+    context.model_alias = None
+    context.n_ctx = None
+    context.n_predict = None
+    context.n_server_predict = None
+    context.n_slots = None
+    context.server_api_key = None
+    context.server_continuous_batching = False
+    context.server_embeddings = False
+    context.server_seed = None
+    context.user_api_key = None
+
+    context.tasks_result = []
+    context.concurrent_tasks = []
+    context.prompts = []
+
+
+@step(u'a model file {model_file}')
+def step_model_file(context, model_file):
+    context.model_file = model_file
+
+
+@step(u'a model alias {model_alias}')
+def step_model_alias(context, model_alias):
+    context.model_alias = model_alias
+
+
+@step(u'{seed} as server seed')
+def step_seed(context, seed):
+    context.server_seed = int(seed)
+
+
+@step(u'{n_ctx} KV cache size')
+def step_n_ctx(context, n_ctx):
+    context.n_ctx = int(n_ctx)
+
+
+@step(u'{n_slots} slots')
+def step_n_slots(context, n_slots):
+    context.n_slots = int(n_slots)
+
+
+@step(u'{n_predict} server max tokens to predict')
+def step_server_n_predict(context, n_predict):
+    context.n_server_predict = int(n_predict)
+
+
+@step(u'continuous batching')
+def step_server_continuous_batching(context):
+    context.server_continuous_batching = True
+
+
+@step(u'embeddings extraction')
+def step_server_embeddings(context):
+    context.server_embeddings = True
+
+
+@step(u"the server is starting")
+def step_start_server(context):
+    start_server_background(context)
+    attempts = 0
+    while True:
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            result = sock.connect_ex((context.server_fqdn, context.server_port))
+            if result == 0:
+                print("\x1b[33;46mserver started!\x1b[0m")
+                return
+            attempts += 1
+            if attempts > 20:
+                assert False, "server not started"
+            print(f"waiting for server to start, connect error code = {result}...")
+            time.sleep(0.1)
+
+
+@step(u"the server is {expecting_status}")
+@async_run_until_complete
+async def step_wait_for_the_server_to_be_started(context, expecting_status):
+    match expecting_status:
+        case 'healthy':
+            await wait_for_health_status(context, context.base_url, 200, 'ok')
+
+        case 'ready' | 'idle':
+            await wait_for_health_status(context, context.base_url, 200, 'ok',
+                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
+                                         slots_idle=context.n_slots,
+                                         slots_processing=0,
+                                         expected_slots=[{'id': slot_id, 'state': 0}
+                                                         for slot_id in range(context.n_slots)])
+        case 'busy':
+            await wait_for_health_status(context, context.base_url, 503,
+                                         'no slot available',
+                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
+                                         slots_idle=0,
+                                         slots_processing=context.n_slots,
+                                         expected_slots=[{'id': slot_id, 'state': 1}
+                                                         for slot_id in range(context.n_slots)])
+        case _:
+            assert False, "unknown status"
+
+
+@step(u'all slots are {expected_slot_status_string}')
+@async_run_until_complete
+async def step_all_slots_status(context, expected_slot_status_string):
+    match expected_slot_status_string:
+        case 'idle':
+            expected_slot_status = 0
+        case 'busy':
+            expected_slot_status = 1
+        case _:
+            assert False, "unknown status"
+
+    expected_slots = [{'id': slot_id, 'state': expected_slot_status}
+                      for slot_id in range(context.n_slots)]
+    await request_slots_status(context, expected_slots)
+
+
+@step(u'a completion request with {api_error} api error')
+@async_run_until_complete
+async def step_request_completion(context, api_error):
+    expect_api_error = api_error == 'raised'
+    completion = await request_completion(context.prompts.pop(),
+                                          context.base_url,
+                                          debug=context.debug,
+                                          n_predict=context.n_predict,
+                                          server_seed=context.server_seed,
+                                          expect_api_error=expect_api_error,
+                                          user_api_key=context.user_api_key)
+    context.tasks_result.append(completion)
+    if context.debug:
+        print(f"Completion response: {completion}")
+    if expect_api_error:
+        assert completion == 401, f"completion must be an 401 status code: {completion}"
+
+
+@step(u'{predicted_n} tokens are predicted matching {re_content}')
+def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
+    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
+
+
+@step(u'{predicted_n} tokens are predicted')
+def step_n_tokens_predicted(context, predicted_n):
+    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
+
+
+@step(u'a user prompt {user_prompt}')
+def step_user_prompt(context, user_prompt):
+    context.prompts.append(user_prompt)
+
+
+@step(u'a system prompt {system_prompt}')
+def step_system_prompt(context, system_prompt):
+    context.system_prompt = system_prompt
+
+
+@step(u'a model {model}')
+def step_model(context, model):
+    context.model = model
+
+
+@step(u'{max_tokens} max tokens to predict')
+def step_max_tokens(context, max_tokens):
+    context.n_predict = int(max_tokens)
+
+
+@step(u'streaming is {enable_streaming}')
+def step_streaming(context, enable_streaming):
+    context.enable_streaming = enable_streaming == 'enabled'
+
+
+@step(u'a user api key {user_api_key}')
+def step_user_api_key(context, user_api_key):
+    context.user_api_key = user_api_key
+
+
+@step(u'no user api key')
+def step_no_user_api_key(context):
+    context.user_api_key = None
+
+
+@step(u'a user api key ')
+def step_no_user_api_key_space(context):
+    context.user_api_key = None
+
+
+@step(u'a server api key {server_api_key}')
+def step_server_api_key(context, server_api_key):
+    context.server_api_key = server_api_key
+
+
+@step(u'an OAI compatible chat completions request with {api_error} api error')
+@async_run_until_complete
+async def step_oai_chat_completions(context, api_error):
+    if context.debug:
+        print(f"Submitting OAI compatible completions request...")
+    expect_api_error = api_error == 'raised'
+    completion = await oai_chat_completions(context.prompts.pop(),
+                                            context.system_prompt,
+                                            context.base_url,
+                                            False,
+                                            model=context.model if hasattr(context, 'model') else None,
+
+                                            n_predict=context.n_predict
+                                            if hasattr(context, 'n_predict') else None,
+
+                                            enable_streaming=context.enable_streaming
+                                            if hasattr(context, 'enable_streaming') else None,
+
+                                            server_seed=context.server_seed
+                                            if hasattr(context, 'server_seed') else None,
+
+                                            user_api_key=context.user_api_key
+                                            if hasattr(context, 'user_api_key') else None,
+
+                                            expect_api_error=expect_api_error)
+    context.tasks_result.append(completion)
+    if context.debug:
+        print(f"Completion response: {completion}")
+    if expect_api_error:
+        assert completion == 401, f"completion must be an 401 status code: {completion}"
+
+    if context.debug:
+        print(f"Completion response: {completion}")
+
+
+@step(u'a prompt')
+def step_a_prompt(context):
+    context.prompts.append(context.text)
+
+
+@step(u'a prompt {prompt}')
+def step_a_prompt_prompt(context, prompt):
+    context.prompts.append(prompt)
+
+
+@step(u'concurrent completion requests')
+@async_run_until_complete()
+async def step_concurrent_completion_requests(context):
+    await concurrent_completion_requests(context,
+                                         request_completion,
+                                         # prompt is inserted automatically
+                                         context.base_url,
+                                         debug=context.debug,
+                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                                         user_api_key=context.user_api_key if hasattr(context,
+                                                                                      'user_api_key') else None)
+
+
+@step(u'concurrent OAI completions requests')
+@async_run_until_complete
+async def step_oai_chat_completions(context):
+    await concurrent_completion_requests(context, oai_chat_completions,
+                                         # user_prompt is inserted automatically
+                                         context.system_prompt,
+                                         context.base_url,
+                                         True,  # async_client
+                                         model=context.model
+                                         if hasattr(context, 'model') else None,
+                                         n_predict=context.n_predict
+                                         if hasattr(context, 'n_predict') else None,
+                                         enable_streaming=context.enable_streaming
+                                         if hasattr(context, 'enable_streaming') else None,
+                                         server_seed=context.server_seed
+                                         if hasattr(context, 'server_seed') else None,
+                                         user_api_key=context.user_api_key
+                                         if hasattr(context, 'user_api_key') else None)
+
+
+@step(u'all prompts are predicted')
+@async_run_until_complete
+async def step_all_prompts_are_predicted(context):
+    await all_prompts_are_predicted(context)
+
+
+@step(u'all prompts are predicted with {n_predict} tokens')
+@async_run_until_complete
+async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
+    expected_predicted_n = int(n_predict)
+    await all_prompts_are_predicted(context, expected_predicted_n)
+
+
+async def all_prompts_are_predicted(context, expected_predicted_n=None):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
+
+@step(u'embeddings are computed for')
+@async_run_until_complete
+async def step_compute_embedding(context):
+    content = context.text
+    base_url = context.base_url
+    context.embeddings = await request_embedding(content, base_url)
+
+
+@step(u'embeddings are generated')
+def step_assert_embeddings(context):
+    assert_embeddings(context.embeddings)
+
+
+@step(u'an OAI compatible embeddings computation request for')
+def step_oai_compute_embedding(context):
+    openai.api_key = 'nope'  # openai client always expects an api_keu
+    if context.user_api_key is not None:
+        openai.api_key = context.user_api_key
+    openai.api_base = f'{context.base_url}/v1'
+    embeddings = openai.Embedding.create(
+        model=context.model,
+        input=context.text,
+    )
+    context.embeddings = embeddings
+
+
+@step(u'concurrent embedding requests')
+@async_run_until_complete()
+async def step_concurrent_embedding_requests(context):
+    await concurrent_completion_requests(context,
+                                         request_embedding,
+                                         # prompt is inserted automatically
+                                         context.base_url)
+
+
+@step(u'all embeddings are generated')
+@async_run_until_complete()
+async def all_embeddings_are_generated(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    for i in range(n_embedding_requests):
+        assert_embeddings(context.tasks_result.pop())
+
+
+@step(u'tokenizing')
+@async_run_until_complete
+async def step_tokenize(context):
+    context.tokenized_text = context.text
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/tokenize',
+                                json={
+                                    "content": context.tokenized_text,
+                                }) as response:
+            assert response.status == 200
+            tokenize_json = await response.json()
+            context.tokens = tokenize_json['tokens']
+
+
+@step(u'tokens can be detokenize')
+@async_run_until_complete
+async def step_detokenize(context):
+    assert len(context.tokens) > 0
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/detokenize',
+                                json={
+                                    "tokens": context.tokens,
+                                }) as response:
+            assert response.status == 200
+            detokenize_json = await response.json()
+            # SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15
+            assert context.tokenized_text == detokenize_json['content'].strip()
+
+
+@step(u'an OPTIONS request is sent from {origin}')
+@async_run_until_complete
+async def step_options_request(context, origin):
+    async with aiohttp.ClientSession() as session:
+        async with session.options(f'{context.base_url}/v1/chat/completions',
+                                   headers={"Origin": origin}) as response:
+            assert response.status == 200
+            context.options_response = response
+
+
+@step(u'CORS header {cors_header} is set to {cors_header_value}')
+def step_check_options_header_value(context, cors_header, cors_header_value):
+    assert context.options_response.headers[cors_header] == cors_header_value
+
+
+async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+    n_prompts = len(context.prompts)
+    if context.debug:
+        print(f"starting {n_prompts} concurrent completion requests...")
+    assert n_prompts > 0
+    for prompt_no in range(n_prompts):
+        shifted_args = [context.prompts.pop(), *args]
+        context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
+    await asyncio.sleep(0.1)
+
+
+async def request_completion(prompt,
+                             base_url,
+                             debug=False,
+                             n_predict=None,
+                             server_seed=None,
+                             expect_api_error=None,
+                             user_api_key=None):
+    if debug:
+        print(f"Sending completion request: {prompt}")
+    origin = "my.super.domain"
+    headers = {
+        'Origin': origin
+    }
+    if user_api_key is not None:
+        if debug:
+            print(f"Set user_api_key: {user_api_key}")
+        headers['Authorization'] = f'Bearer {user_api_key}'
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{base_url}/completion',
+                                json={
+                                    "prompt": prompt,
+                                    "n_predict": int(n_predict) if n_predict is not None else -1,
+                                    "seed": server_seed if server_seed is not None else 42
+                                },
+                                headers=headers) as response:
+            if expect_api_error is None or not expect_api_error:
+                assert response.status == 200
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                return await response.json()
+            else:
+                return response.status
+
+
+async def oai_chat_completions(user_prompt,
+                               system_prompt,
+                               base_url,
+                               async_client,
+                               debug=False,
+                               model=None,
+                               n_predict=None,
+                               enable_streaming=None,
+                               server_seed=None,
+                               user_api_key=None,
+                               expect_api_error=None):
+    if debug:
+        print(f"Sending OAI Chat completions request: {user_prompt}")
+    # openai client always expects an api key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    seed = server_seed if server_seed is not None else 42
+    enable_streaming = enable_streaming if enable_streaming is not None else False
+    payload = {
+        "messages": [
+            {
+                "role": "system",
+                "content": system_prompt,
+            },
+            {
+                "role": "user",
+                "content": user_prompt,
+            }
+        ],
+        "model": model,
+        "max_tokens": n_predict,
+        "stream": enable_streaming,
+        "seed": seed
+    }
+    completion_response = {
+        'content': '',
+        'timings': {
+            'predicted_n': 0
+        }
+    }
+    if async_client:
+        origin = 'llama.cpp'
+        headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/chat/completions',
+                                    json=payload,
+                                    headers=headers) as response:
+                if enable_streaming:
+                    assert response.status == 200
+                    assert response.headers['Access-Control-Allow-Origin'] == origin
+                    assert response.headers['Content-Type'] == "text/event-stream"
+                    event_received = True
+                    while event_received:
+                        event_received = False
+                        async for line_in_bytes in response.content:
+                            line = line_in_bytes.decode('utf8')
+                            line = line.rstrip('\n').rstrip('\r')
+                            if line == '':
+                                continue
+                            event_data = line.split(': ', 1)
+                            assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```'
+                            chunk_raw = event_data[1]
+
+                            chunk = json.loads(chunk_raw)
+                            assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```"
+                            delta = chunk['choices'][0]['delta']
+                            if 'content' in delta:
+                                completion_response['content'] += delta['content']
+                                completion_response['timings']['predicted_n'] += 1
+                else:
+                    if expect_api_error is None or not expect_api_error:
+                        assert response.status == 200
+                        assert response.headers['Access-Control-Allow-Origin'] == origin
+                        assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                        chat_completion_raw = await response.json()
+                        completion_response = {
+                            'content': chat_completion_raw['choices'][0]['message'],
+                            'timings': {
+                                'predicted_n': chat_completion_raw['usage']['completion_tokens']
+                            }
+                        }
+                    else:
+                        return response.status
+    else:
+        try:
+            openai.api_key = user_api_key
+            openai.api_base = f'{base_url}/v1/chat'
+            chat_completion = openai.Completion.create(
+                messages=payload['messages'],
+                model=model,
+                max_tokens=n_predict,
+                stream=enable_streaming,
+                seed=seed
+            )
+        except openai.error.APIError as e:
+            if expect_api_error is not None and expect_api_error:
+                return 401
+            else:
+                assert False, f'error raised: {e}'
+
+        if enable_streaming:
+            for chunk in chat_completion:
+                assert len(chunk.choices) == 1
+                delta = chunk.choices[0].delta
+                if 'content' in delta:
+                    completion_response['content'] += delta['content']
+                    completion_response['timings']['predicted_n'] += 1
+        else:
+            assert len(chat_completion.choices) == 1
+            completion_response = {
+                'content': chat_completion.choices[0].message.content,
+                'timings': {
+                    'predicted_n': chat_completion.usage.completion_tokens
+                }
+            }
+    if debug:
+        print("OAI response formatted to llama.cpp:", completion_response)
+    return completion_response
+
+
+async def request_embedding(content, base_url):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{base_url}/embedding',
+                                json={
+                                    "content": content,
+                                }) as response:
+            assert response.status == 200
+            response_json = await response.json()
+            return response_json['embedding']
+
+
+def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
+    content = completion_response['content']
+    n_predicted = completion_response['timings']['predicted_n']
+    assert len(content) > 0, "no token predicted"
+    if expected_predicted_n is not None:
+        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
+                                                     f' {n_predicted} <> {expected_predicted_n}')
+    if re_content is not None:
+        re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
+        assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
+            f'invalid tokens predicted:'
+            f' ```\n{content}\n``` do not match /{re_content}/')
+
+
+async def gather_tasks_results(context):
+    n_tasks = len(context.concurrent_tasks)
+    if context.debug:
+        print(f"Waiting for all {n_tasks} tasks results...")
+    for task_no in range(n_tasks):
+        context.tasks_result.append(await context.concurrent_tasks.pop())
+    n_completions = len(context.tasks_result)
+    return n_completions
+
+
+async def wait_for_health_status(context,
+                                 base_url,
+                                 expected_http_status_code,
+                                 expected_health_status,
+                                 params=None,
+                                 slots_idle=None,
+                                 slots_processing=None,
+                                 expected_slots=None):
+    if context.debug:
+        print(f"Starting checking for health for expected_health_status={expected_health_status}")
+    timeout = 3  # seconds
+    interval = 0.5
+    counter = 0
+    async with aiohttp.ClientSession() as session:
+        while True:
+            async with await session.get(f'{base_url}/health', params=params) as health_response:
+                status_code = health_response.status
+                health = await health_response.json()
+                if context.debug:
+                    print(f"HEALTH - response for expected health status='{expected_health_status}' on "
+                          f"'{base_url}/health'?{params} is {health}")
+                if (status_code == expected_http_status_code
+                        and health['status'] == expected_health_status
+                        and (slots_idle is None or health['slots_idle'] == slots_idle)
+                        and (slots_processing is None or health['slots_processing'] == slots_processing)):
+                    if expected_slots is not None:
+                        assert_slots_status(health['slots'], expected_slots)
+                    return
+                if (status_code == expected_http_status_code
+                        and health['status'] == expected_health_status
+                        and (slots_idle is None or health['slots_idle'] == slots_idle)
+                        and (slots_processing is None or health['slots_processing'] == slots_processing)):
+                    if expected_slots is not None:
+                        assert_slots_status(health['slots'], expected_slots)
+                    return
+            await asyncio.sleep(interval)
+
+            counter += interval
+            if counter >= timeout:
+                # Sometimes health requests are triggered after completions are predicted
+                if expected_http_status_code == 503:
+                    if len(context.tasks_result) == 0:
+                        print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
+                              " busy health check missed, probably too fast inference\x1b[0m")
+                        n_completions = await gather_tasks_results(context)
+                        if n_completions > 0:
+                            return
+
+                assert False, 'timeout exceeded'
+
+
+def assert_embeddings(embeddings):
+    assert len(embeddings) > 0
+    embeddings_computed = False
+    for emb in embeddings:
+        if emb != 0:
+            embeddings_computed = True
+    assert embeddings_computed, f"Embeddings: {embeddings}"
+
+
+async def request_slots_status(context, expected_slots):
+    async with aiohttp.ClientSession() as session:
+        async with await session.get(f'{context.base_url}/slots') as slots_response:
+            assert slots_response.status == 200
+            slots = await slots_response.json()
+            assert_slots_status(slots, expected_slots)
+
+
+def assert_slots_status(slots, expected_slots):
+    assert len(slots) == len(expected_slots)
+    for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
+        for key in expected:
+            assert expected[key] == slot[key], (f"invalid slot {slot_id}"
+                                                f" expected[{key}] != slot[{key}]"
+                                                f" = {expected[key]} != {slot[key]}")
+
+
+def start_server_background(context):
+    context.server_path = '../../../build/bin/server'
+    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
+        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
+    server_args = [
+        '--host', context.server_fqdn,
+        '--port', context.server_port,
+        '--model', context.model_file
+    ]
+    if context.server_continuous_batching:
+        server_args.append('--cont-batching')
+    if context.server_embeddings:
+        server_args.append('--embedding')
+    if context.model_alias is not None:
+        server_args.extend(['--alias', context.model_alias])
+    if context.n_ctx is not None:
+        server_args.extend(['--ctx-size', context.n_ctx])
+    if context.n_slots is not None:
+        server_args.extend(['--parallel', context.n_slots])
+    if context.n_server_predict is not None:
+        server_args.extend(['--n-predict', context.n_server_predict])
+    if context.server_api_key is not None:
+        server_args.extend(['--api-key', context.server_api_key])
+    if context.debug:
+        server_args.append('--verbose')
+    print(f"starting server with: {context.server_path}", *server_args)
+    context.server_process = subprocess.Popen(
+        [str(arg) for arg in [context.server_path, *server_args]],
+        close_fds=True)
+    print(f"server pid={context.server_process.pid}")

+ 21 - 0
examples/server/tests/features/wrong_usages.feature

@@ -0,0 +1,21 @@
+# run with ./test.sh --tags wrong_usage
+@wrong_usage
+Feature: Wrong usage of llama.cpp server
+
+  #3969 The user must always set --n-predict option
+  # to cap the number of tokens any completion request can generate
+  # or pass n_predict/max_tokens in the request.
+  Scenario: Infinite loop
+    Given a server listening on localhost:8080
+    And   a model file stories260K.gguf
+    # Uncomment below to fix the issue
+    #And   64 server max tokens to predict
+    Then  the server is starting
+    Given a prompt:
+      """
+      Go to: infinite loop
+      """
+    # Uncomment below to fix the issue
+    #And   128 max tokens to predict
+    Given concurrent completion requests
+    Then all prompts are predicted

+ 3 - 0
examples/server/tests/requirements.txt

@@ -0,0 +1,3 @@
+aiohttp~=3.9.3
+behave~=1.2.6
+openai~=0.25.0

+ 12 - 0
examples/server/tests/tests.sh

@@ -0,0 +1,12 @@
+#!/bin/bash
+
+set -eu
+
+if [ $# -lt 1 ]
+then
+  # Start @llama.cpp scenario
+  behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
+else
+  behave "$@"
+fi
+