
server: add router multi-model tests (#17704) (#17722)

* llama-server: add router multi-model tests (#17704)

Add 4 test cases for model router:
- test_router_unload_model: explicit model unloading
- test_router_models_max_evicts_lru: LRU eviction with --models-max
- test_router_no_models_autoload: --no-models-autoload flag behavior
- test_router_api_key_required: API key authentication

Tests use async model loading with polling and gracefully skip when there
are not enough models available for eviction testing; the load-and-poll
flow is sketched below.
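
For context, the load-and-poll flow those tests follow can be reproduced
directly against a running router server. A minimal sketch (the base URL
and timeout are placeholders; only the POST /models/load and GET /models
endpoints exercised by the tests are assumed):

    import time
    import requests

    BASE_URL = "http://127.0.0.1:8080"  # placeholder; assumes a llama-server router is listening here

    def model_status(model_id: str) -> str:
        # GET /models lists every known model together with a status object
        data = requests.get(f"{BASE_URL}/models").json()["data"]
        item = next(m for m in data if m.get("id") == model_id or m.get("model") == model_id)
        return item["status"]["value"]

    def load_and_wait(model_id: str, timeout: int = 60) -> None:
        # POST /models/load only kicks off loading; poll until the model reports "loaded"
        res = requests.post(f"{BASE_URL}/models/load", json={"model": model_id})
        assert res.status_code == 200 and res.json().get("success") is True
        deadline = time.time() + timeout
        while time.time() < deadline:
            if model_status(model_id) == "loaded":
                return
            time.sleep(1)
        raise AssertionError(f"{model_id} did not reach 'loaded' within {timeout}s")

The unload test mirrors this with POST /models/unload followed by a wait
for the "unloaded" status.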

utils.py changes:
- Add models_max, models_dir, no_models_autoload attributes to ServerProcess
- Handle JSONDecodeError for non-JSON error responses (fall back to the raw
  response text; see the caller-side sketch below)
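
With that fallback in place, a body returned by make_request may be either a
parsed dict or a raw string, so callers that inspect error responses have to
tolerate both. A hedged sketch of what that looks like (the helper name is
hypothetical, not part of the test suite):

    def error_text(body) -> str:
        # body is a dict when the server returned JSON, or plain text after the
        # JSONDecodeError fallback in make_request
        if isinstance(body, dict):
            err = body.get("error", {})
            return err.get("message", "") if isinstance(err, dict) else str(err)
        return body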

* llama-server: update test models to new HF repos

* add offline

* llama-server: fix router LRU eviction test and add preloading

Fix the eviction test: load 2 models first, verify their state, then load
a 3rd to trigger eviction. The previous logic loaded all 3 at once, causing
the first model to be evicted before it could be verified.

Add a module fixture to preload models via ServerPreset.load_all() and mark
the test presets as offline so they use the cached models; a sketch of such
a fixture follows.
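
A sketch of what such a module-scoped preload fixture can look like (the
ServerPreset.load_all name is taken from this commit message; treat the
exact call as an assumption):

    import pytest
    from utils import ServerPreset  # test helpers under tools/server/tests

    @pytest.fixture(scope="module", autouse=True)
    def preload_models():
        # download/cache every preset model once per test module so the presets
        # can then run with offline = True
        ServerPreset.load_all()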

* llama-server: fix split model download on Windows

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Pascal, 1 month ago
Commit e7c2cf1356

+ 1 - 0
tools/server/tests/unit/test_basic.py

@@ -65,6 +65,7 @@ def test_server_slots():
 
 def test_load_split_model():
     global server
+    server.offline = False
     server.model_hf_repo = "ggml-org/models"
     server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
     server.model_alias = "tinyllama-split"

+ 145 - 1
tools/server/tests/unit/test_router.py

@@ -17,7 +17,6 @@ def create_server():
     ]
 )
 def test_router_chat_completion_stream(model: str, success: bool):
-    # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
     global server
     server.start()
     content = ""
@@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool):
     else:
         assert ex is not None
         assert content == ""
+
+
+def _get_model_status(model_id: str) -> str:
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    for item in res.body.get("data", []):
+        if item.get("id") == model_id or item.get("model") == model_id:
+            return item["status"]["value"]
+    raise AssertionError(f"Model {model_id} not found in /models response")
+
+
+def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
+    deadline = time.time() + timeout
+    last_status = None
+    while time.time() < deadline:
+        last_status = _get_model_status(model_id)
+        if last_status in desired:
+            return last_status
+        time.sleep(1)
+    raise AssertionError(
+        f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
+    )
+
+
+def _load_model_and_wait(
+    model_id: str, timeout: int = 60, headers: dict | None = None
+) -> None:
+    load_res = server.make_request(
+        "POST", "/models/load", data={"model": model_id}, headers=headers
+    )
+    assert load_res.status_code == 200
+    assert isinstance(load_res.body, dict)
+    assert load_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
+
+
+def test_router_unload_model():
+    global server
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    _load_model_and_wait(model_id)
+
+    unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
+    assert unload_res.status_code == 200
+    assert unload_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"unloaded"})
+
+
+def test_router_models_max_evicts_lru():
+    global server
+    server.models_max = 2
+    server.start()
+
+    candidate_models = [
+        "ggml-org/tinygemma3-GGUF:Q8_0",
+        "ggml-org/test-model-stories260K",
+        "ggml-org/test-model-stories260K-infill",
+    ]
+
+    # Load only the first 2 models to fill the cache
+    first, second, third = candidate_models[:3]
+
+    _load_model_and_wait(first, timeout=120)
+    _load_model_and_wait(second, timeout=120)
+
+    # Verify both models are loaded
+    assert _get_model_status(first) == "loaded"
+    assert _get_model_status(second) == "loaded"
+
+    # Load the third model - this should trigger LRU eviction of the first model
+    _load_model_and_wait(third, timeout=120)
+
+    # Verify eviction: third is loaded, first was evicted
+    assert _get_model_status(third) == "loaded"
+    assert _get_model_status(first) == "unloaded"
+
+
+def test_router_no_models_autoload():
+    global server
+    server.no_models_autoload = True
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    _load_model_and_wait(model_id)
+
+    success_res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert success_res.status_code == 200
+    assert "error" not in success_res.body
+
+
+def test_router_api_key_required():
+    global server
+    server.api_key = "sk-router-secret"
+    server.start()
+
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+    auth_headers = {"Authorization": f"Bearer {server.api_key}"}
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 401
+    assert res.body.get("error", {}).get("type") == "authentication_error"
+
+    _load_model_and_wait(model_id, headers=auth_headers)
+
+    authed = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        headers=auth_headers,
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert authed.status_code == 200
+    assert "error" not in authed.body

+ 23 - 5
tools/server/tests/utils.py

@@ -7,6 +7,7 @@ import subprocess
 import os
 import re
 import json
+from json import JSONDecodeError
 import sys
 import requests
 import time
@@ -83,6 +84,9 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
+    models_dir: str | None = None
+    models_max: int | None = None
+    no_models_autoload: bool | None = None
     lora_files: List[str] | None = None
     enable_ctx_shift: int | None = False
     draft_min: int | None = None
@@ -143,6 +147,10 @@ class ServerProcess:
             server_args.extend(["--hf-repo", self.model_hf_repo])
         if self.model_hf_file:
             server_args.extend(["--hf-file", self.model_hf_file])
+        if self.models_dir:
+            server_args.extend(["--models-dir", self.models_dir])
+        if self.models_max is not None:
+            server_args.extend(["--models-max", self.models_max])
         if self.n_batch:
             server_args.extend(["--batch-size", self.n_batch])
         if self.n_ubatch:
@@ -204,6 +212,8 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.no_models_autoload:
+            server_args.append("--no-models-autoload")
         if self.jinja:
             server_args.append("--jinja")
         else:
@@ -295,7 +305,13 @@ class ServerProcess:
         result = ServerResponse()
         result.headers = dict(response.headers)
         result.status_code = response.status_code
-        result.body = response.json() if parse_body else None
+        if parse_body:
+            try:
+                result.body = response.json()
+            except JSONDecodeError:
+                result.body = response.text
+        else:
+            result.body = None
         print("Response from server", json.dumps(result.body, indent=2))
         return result
 
@@ -434,8 +450,9 @@ class ServerPreset:
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
-        server.model_hf_repo = "ggml-org/models"
-        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.offline = True # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/test-model-stories260K"
+        server.model_hf_file = None
         server.model_alias = "tinyllama-2"
         server.n_ctx = 512
         server.n_batch = 32
@@ -479,8 +496,8 @@ class ServerPreset:
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
         server.offline = True # will be downloaded by load_all()
-        server.model_hf_repo = "ggml-org/models"
-        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
+        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
+        server.model_hf_file = None
         server.model_alias = "tinyllama-infill"
         server.n_ctx = 2048
         server.n_batch = 1024
@@ -537,6 +554,7 @@ class ServerPreset:
     @staticmethod
     def router() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # router server has no models
         server.model_file = None
         server.model_alias = None