
server : (refactor) no more json in server_task input (#10691)

* server : (refactor) no more json in server_task input

* add test for slots endpoint

* add tests for /props and /slots

* remove task inf_type

* fix CI by adding safe_json_to_str

* add "model_path" to /props

* update readme
Xuan Son Nguyen 1 year ago
Parent
Commit
3573fa8e7b

+ 2 - 0
examples/server/README.md

@@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     }
   },
   "total_slots": 1,
+  "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
   "chat_template": "..."
 }
 ```

 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
+- `model_path` - the path to model file (same with `-m` argument)
 - `chat_template` - the model's original Jinja2 prompt template

 ### POST `/props`: Change server global properties.
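As a usage note for the `/props` fields documented above, here is a minimal sketch that queries the endpoint and reads the new `model_path` field; the host and port are assumptions for illustration, not part of the commit.

```python
# Minimal sketch (not part of the commit): read the /props fields described
# above from a locally running llama-server; localhost:8080 is an assumption.
import requests

props = requests.get("http://localhost:8080/props").json()

print(props["model_path"])            # path passed via -m, new in this commit
print(props["total_slots"])           # number of slots, i.e. the --parallel value
print(props["default_generation_settings"]["n_ctx"])  # per-slot context size
print(props["chat_template"][:80])    # model's original Jinja2 template (truncated)
```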

The file diff is not shown because it is too large
+ 322 - 355
examples/server/server.cpp


+ 30 - 0
examples/server/tests/unit/test_basic.py

@@ -22,7 +22,12 @@ def test_server_props():
     server.start()
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
+    assert ".gguf" in res.body["model_path"]
     assert res.body["total_slots"] == server.n_slots
+    default_val = res.body["default_generation_settings"]
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert default_val["n_ctx"] == server.n_ctx / server.n_slots
+    assert default_val["params"]["seed"] == server.seed


 def test_server_models():
@@ -33,6 +38,31 @@ def test_server_models():
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias

+
+def test_server_slots():
+    global server
+
+    # without slots endpoint enabled, this should return error
+    server.server_slots = False
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
+    assert "error" in res.body
+    server.stop()
+
+    # with slots endpoint enabled, this should return slots info
+    server.server_slots = True
+    server.n_slots = 2
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 200
+    assert len(res.body) == server.n_slots
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
+    assert "params" in res.body[0]
+    assert res.body[0]["params"]["seed"] == server.seed
+
+
 def test_load_split_model():
     global server
     server.model_hf_repo = "ggml-org/models"
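The new `test_server_slots` covers both the disabled and enabled cases of the `/slots` endpoint. Outside the test harness the same behaviour can be checked directly; a rough sketch, assuming a server on localhost:8080 started with `--slots --parallel 2` (both assumptions for illustration):

```python
# Sketch only: inspect per-slot state via /slots; localhost:8080 and the
# server flags (--slots --parallel 2) are assumptions for illustration.
import requests

res = requests.get("http://localhost:8080/slots")
if res.status_code == 501:
    # server was started without --slots (ERROR_TYPE_NOT_SUPPORTED)
    print("slots endpoint disabled:", res.json()["error"])
else:
    slots = res.json()  # one entry per slot
    for slot in slots:
        # each slot reports its own context size and sampling params
        print(slot["n_ctx"], slot["params"]["seed"])
```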

+ 5 - 0
examples/server/tests/unit/test_chat_completion.py

@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
     assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         "stream": True,
         "stream": True,
     })
     })
     content = ""
     content = ""
+    last_cmpl_id = None
     for data in res:
     for data in res:
         choice = data["choices"][0]
         choice = data["choices"][0]
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        if last_cmpl_id is None:
+            last_cmpl_id = data["id"]
+        assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
         if choice["finish_reason"] in ["stop", "length"]:
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
             assert data["usage"]["completion_tokens"] == n_predicted
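The streamed-id assertion added above can also be reproduced with a plain SSE client outside the test suite; a minimal sketch, assuming the OpenAI-compatible endpoint of a server on localhost:8080:

```python
# Sketch only (not part of the commit): consume the SSE stream from the
# OAI-compatible endpoint and check that every chunk reuses one id.
import json
import requests

with requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello"}],
          "max_tokens": 8, "stream": True},
    stream=True,
) as res:
    stream_id = None
    for line in res.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: ") or line == "data: [DONE]":
            continue
        chunk = json.loads(line[len("data: "):])
        stream_id = stream_id or chunk["id"]
        assert chunk["id"] == stream_id  # one completion id for the whole stream
```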

+ 6 - 4
examples/server/tests/utils.py

@@ -64,6 +64,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    server_slots: bool | None = False
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -91,7 +92,6 @@ class ServerProcess:
         else:
             server_path = "../../../build/bin/llama-server"
         server_args = [
-            "--slots",  # requires to get slot status via /slots endpoint
             "--host",
             self.server_host,
             "--port",
@@ -129,6 +129,8 @@ class ServerProcess:
             server_args.append("--reranking")
             server_args.append("--reranking")
         if self.server_metrics:
         if self.server_metrics:
             server_args.append("--metrics")
             server_args.append("--metrics")
+        if self.server_slots:
+            server_args.append("--slots")
         if self.model_alias:
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
         if self.n_ctx:
@@ -181,7 +183,7 @@ class ServerProcess:
         start_time = time.time()
         while time.time() - start_time < timeout_seconds:
             try:
-                response = self.make_request("GET", "/slots", headers={
+                response = self.make_request("GET", "/health", headers={
                     "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                 })
                 if response.status_code == 200:
@@ -224,7 +226,7 @@ class ServerProcess:
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json() if parse_body else None
-        print("Response from server", result.body)
+        print("Response from server", json.dumps(result.body, indent=2))
         return result

     def make_stream_request(
@@ -245,7 +247,7 @@ class ServerProcess:
                 break
             elif line.startswith('data: '):
                 data = json.loads(line[6:])
-                print("Partial response from server", data)
+                print("Partial response from server", json.dumps(data, indent=2))
                 yield data
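With the harness now polling `/health` instead of `/slots` for readiness, the wait loop reduces to something like the sketch below; the URL and timeout are illustrative values, not taken from the commit.

```python
# Sketch of the readiness loop used by the test harness, reduced to its
# essentials; the URL and 10 s timeout are illustrative values.
import time
import requests

def wait_until_ready(base_url: str = "http://localhost:8080", timeout_seconds: float = 10.0) -> bool:
    start = time.time()
    while time.time() - start < timeout_seconds:
        try:
            if requests.get(f"{base_url}/health").status_code == 200:
                return True
        except requests.exceptions.ConnectionError:
            pass  # server not accepting connections yet
        time.sleep(0.5)
    return False
```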
 
 
 
 

+ 18 - 2
examples/server/utils.hpp

@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }

@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;

-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));

@@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
         {"content", content}
     };
     };
 }
 }
+
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+    return data;
+}
+
+static std::string safe_json_to_str(json data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
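The new `safe_json_to_str()` serializes with `json::error_handler_t::replace`, so invalid UTF-8 in generated text is replaced rather than making `dump()` throw (the CI fix mentioned in the commit message). A rough Python analogue of the same idea, for illustration only:

```python
# Python analogue (illustration only) of safe_json_to_str(): serialize text
# that may contain an invalid / truncated UTF-8 sequence without raising,
# by replacing the bad bytes, mirroring json::error_handler_t::replace.
import json

raw = b"partial multibyte char: \xe2\x82"     # truncated 3-byte UTF-8 sequence
text = raw.decode("utf-8", errors="replace")  # invalid bytes become U+FFFD
print(json.dumps({"content": text}))          # serializes without throwing
```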

Some files were not shown because too many files changed in this diff