1
0

test_router.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import pytest
  2. from utils import *
  3. server: ServerProcess
  4. @pytest.fixture(autouse=True)
  5. def create_server():
  6. global server
  7. server = ServerPreset.router()
  8. @pytest.mark.parametrize(
  9. "model,success",
  10. [
  11. ("ggml-org/tinygemma3-GGUF:Q8_0", True),
  12. ("non-existent/model", False),
  13. ]
  14. )
  15. def test_router_chat_completion_stream(model: str, success: bool):
  16. global server
  17. server.start()
  18. content = ""
  19. ex: ServerError | None = None
  20. try:
  21. res = server.make_stream_request("POST", "/chat/completions", data={
  22. "model": model,
  23. "max_tokens": 16,
  24. "messages": [
  25. {"role": "user", "content": "hello"},
  26. ],
  27. "stream": True,
  28. })
  29. for data in res:
  30. if data["choices"]:
  31. choice = data["choices"][0]
  32. if choice["finish_reason"] in ["stop", "length"]:
  33. assert "content" not in choice["delta"]
  34. else:
  35. assert choice["finish_reason"] is None
  36. content += choice["delta"]["content"] or ''
  37. except ServerError as e:
  38. ex = e
  39. if success:
  40. assert ex is None
  41. assert len(content) > 0
  42. else:
  43. assert ex is not None
  44. assert content == ""
  45. def _get_model_status(model_id: str) -> str:
  46. res = server.make_request("GET", "/models")
  47. assert res.status_code == 200
  48. for item in res.body.get("data", []):
  49. if item.get("id") == model_id or item.get("model") == model_id:
  50. return item["status"]["value"]
  51. raise AssertionError(f"Model {model_id} not found in /models response")
  52. def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
  53. deadline = time.time() + timeout
  54. last_status = None
  55. while time.time() < deadline:
  56. last_status = _get_model_status(model_id)
  57. if last_status in desired:
  58. return last_status
  59. time.sleep(1)
  60. raise AssertionError(
  61. f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
  62. )
  63. def _load_model_and_wait(
  64. model_id: str, timeout: int = 60, headers: dict | None = None
  65. ) -> None:
  66. load_res = server.make_request(
  67. "POST", "/models/load", data={"model": model_id}, headers=headers
  68. )
  69. assert load_res.status_code == 200
  70. assert isinstance(load_res.body, dict)
  71. assert load_res.body.get("success") is True
  72. _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
  73. def test_router_unload_model():
  74. global server
  75. server.start()
  76. model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
  77. _load_model_and_wait(model_id)
  78. unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
  79. assert unload_res.status_code == 200
  80. assert unload_res.body.get("success") is True
  81. _wait_for_model_status(model_id, {"unloaded"})
  82. def test_router_models_max_evicts_lru():
  83. global server
  84. server.models_max = 2
  85. server.start()
  86. candidate_models = [
  87. "ggml-org/tinygemma3-GGUF:Q8_0",
  88. "ggml-org/test-model-stories260K",
  89. "ggml-org/test-model-stories260K-infill",
  90. ]
  91. # Load only the first 2 models to fill the cache
  92. first, second, third = candidate_models[:3]
  93. _load_model_and_wait(first, timeout=120)
  94. _load_model_and_wait(second, timeout=120)
  95. # Verify both models are loaded
  96. assert _get_model_status(first) == "loaded"
  97. assert _get_model_status(second) == "loaded"
  98. # Load the third model - this should trigger LRU eviction of the first model
  99. _load_model_and_wait(third, timeout=120)
  100. # Verify eviction: third is loaded, first was evicted
  101. assert _get_model_status(third) == "loaded"
  102. assert _get_model_status(first) == "unloaded"
  103. def test_router_no_models_autoload():
  104. global server
  105. server.no_models_autoload = True
  106. server.start()
  107. model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
  108. res = server.make_request(
  109. "POST",
  110. "/v1/chat/completions",
  111. data={
  112. "model": model_id,
  113. "messages": [{"role": "user", "content": "hello"}],
  114. "max_tokens": 4,
  115. },
  116. )
  117. assert res.status_code == 400
  118. assert "error" in res.body
  119. _load_model_and_wait(model_id)
  120. success_res = server.make_request(
  121. "POST",
  122. "/v1/chat/completions",
  123. data={
  124. "model": model_id,
  125. "messages": [{"role": "user", "content": "hello"}],
  126. "max_tokens": 4,
  127. },
  128. )
  129. assert success_res.status_code == 200
  130. assert "error" not in success_res.body
  131. def test_router_api_key_required():
  132. global server
  133. server.api_key = "sk-router-secret"
  134. server.start()
  135. model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
  136. auth_headers = {"Authorization": f"Bearer {server.api_key}"}
  137. res = server.make_request(
  138. "POST",
  139. "/v1/chat/completions",
  140. data={
  141. "model": model_id,
  142. "messages": [{"role": "user", "content": "hello"}],
  143. "max_tokens": 4,
  144. },
  145. )
  146. assert res.status_code == 401
  147. assert res.body.get("error", {}).get("type") == "authentication_error"
  148. _load_model_and_wait(model_id, headers=auth_headers)
  149. authed = server.make_request(
  150. "POST",
  151. "/v1/chat/completions",
  152. headers=auth_headers,
  153. data={
  154. "model": model_id,
  155. "messages": [{"role": "user", "content": "hello"}],
  156. "max_tokens": 4,
  157. },
  158. )
  159. assert authed.status_code == 200
  160. assert "error" not in authed.body