@@ -17,7 +17,6 @@ def create_server():
     ]
 )
 def test_router_chat_completion_stream(model: str, success: bool):
-    # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
     global server
     server.start()
     content = ""
@@ -48,3 +47,157 @@ def test_router_chat_completion_stream(model: str, success: bool):
     else:
         assert ex is not None
         assert content == ""
+
+
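+# Look up a model's current status in the router's /models listing.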
+def _get_model_status(model_id: str) -> str:
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    for item in res.body.get("data", []):
+        if item.get("id") == model_id or item.get("model") == model_id:
+            return item["status"]["value"]
+    raise AssertionError(f"Model {model_id} not found in /models response")
+
+
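+# Poll /models until the model reaches one of the desired statuses; fail on timeout.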
+def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
+    deadline = time.time() + timeout
+    last_status = None
+    while time.time() < deadline:
+        last_status = _get_model_status(model_id)
+        if last_status in desired:
+            return last_status
+        time.sleep(1)
+    raise AssertionError(
+        f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
+    )
+
+
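+# Request a load via POST /models/load, then wait until the model reports "loaded".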
+def _load_model_and_wait(
+    model_id: str, timeout: int = 60, headers: dict | None = None
+) -> None:
+    load_res = server.make_request(
+        "POST", "/models/load", data={"model": model_id}, headers=headers
+    )
+    assert load_res.status_code == 200
+    assert isinstance(load_res.body, dict)
+    assert load_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
+
+
+def test_router_unload_model():
+    global server
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    _load_model_and_wait(model_id)
+
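+    # Unloading should succeed, and the model should eventually report "unloaded".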
+    unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
+    assert unload_res.status_code == 200
+    assert unload_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"unloaded"})
+
+
+def test_router_models_max_evicts_lru():
+    global server
+    server.models_max = 2
+    server.start()
+
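+    # Three candidate models; with models_max = 2, a third load must evict the least recently used.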
+    candidate_models = [
+        "ggml-org/tinygemma3-GGUF:Q8_0",
+        "ggml-org/test-model-stories260K",
+        "ggml-org/test-model-stories260K-infill",
+    ]
+
+    # Load only the first 2 models to fill the cache
+    first, second, third = candidate_models[:3]
+
+    _load_model_and_wait(first, timeout=120)
+    _load_model_and_wait(second, timeout=120)
+
+    # Verify both models are loaded
+    assert _get_model_status(first) == "loaded"
+    assert _get_model_status(second) == "loaded"
+
+    # Load the third model - this should trigger LRU eviction of the first model
+    _load_model_and_wait(third, timeout=120)
+
+    # Verify eviction: third is loaded, first was evicted
+    assert _get_model_status(third) == "loaded"
+    assert _get_model_status(first) == "unloaded"
+
+
+def test_router_no_models_autoload():
+    global server
+    server.no_models_autoload = True
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
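+    # With autoload disabled, a completion request for an unloaded model should be rejected.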
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    _load_model_and_wait(model_id)
+
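+    # After an explicit load, the same request should succeed.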
+    success_res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert success_res.status_code == 200
+    assert "error" not in success_res.body
+
+
+def test_router_api_key_required():
+    global server
+    server.api_key = "sk-router-secret"
+    server.start()
+
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+    auth_headers = {"Authorization": f"Bearer {server.api_key}"}
+
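+    # Without the Authorization header, the request should be rejected.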
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 401
+    assert res.body.get("error", {}).get("type") == "authentication_error"
+
+    _load_model_and_wait(model_id, headers=auth_headers)
+
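+    # The same request with a valid Bearer token should succeed.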
+    authed = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        headers=auth_headers,
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert authed.status_code == 200
+    assert "error" not in authed.body