# test_speculative.py — server tests for speculative decoding (draft model)
  1. import pytest
  2. from utils import *
# We use a F16 MOE gguf as main model, and q4_0 as draft model
server = ServerPreset.stories15m_moe()

# Draft model for speculative decoding: a q4_0 quantization of stories15M.
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
  6. def create_server():
  7. global server
  8. server = ServerPreset.stories15m_moe()
  9. # set default values
  10. server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
  11. server.draft_min = 4
  12. server.draft_max = 8
  13. @pytest.fixture(scope="module", autouse=True)
  14. def fixture_create_server():
  15. return create_server()
  16. def test_with_and_without_draft():
  17. global server
  18. server.model_draft = None # disable draft model
  19. server.start()
  20. res = server.make_request("POST", "/completion", data={
  21. "prompt": "I believe the meaning of life is",
  22. "temperature": 0.0,
  23. "top_k": 1,
  24. })
  25. assert res.status_code == 200
  26. content_no_draft = res.body["content"]
  27. server.stop()
  28. # create new server with draft model
  29. create_server()
  30. server.start()
  31. res = server.make_request("POST", "/completion", data={
  32. "prompt": "I believe the meaning of life is",
  33. "temperature": 0.0,
  34. "top_k": 1,
  35. })
  36. assert res.status_code == 200
  37. content_draft = res.body["content"]
  38. assert content_no_draft == content_draft
  39. def test_different_draft_min_draft_max():
  40. global server
  41. test_values = [
  42. (1, 2),
  43. (1, 4),
  44. (4, 8),
  45. (4, 12),
  46. (8, 16),
  47. ]
  48. last_content = None
  49. for draft_min, draft_max in test_values:
  50. server.stop()
  51. server.draft_min = draft_min
  52. server.draft_max = draft_max
  53. server.start()
  54. res = server.make_request("POST", "/completion", data={
  55. "prompt": "I believe the meaning of life is",
  56. "temperature": 0.0,
  57. "top_k": 1,
  58. })
  59. assert res.status_code == 200
  60. if last_content is not None:
  61. assert last_content == res.body["content"]
  62. last_content = res.body["content"]
  63. def test_slot_ctx_not_exceeded():
  64. global server
  65. server.n_ctx = 64
  66. server.start()
  67. res = server.make_request("POST", "/completion", data={
  68. "prompt": "Hello " * 56,
  69. "temperature": 0.0,
  70. "top_k": 1,
  71. "speculative.p_min": 0.0,
  72. })
  73. assert res.status_code == 200
  74. assert len(res.body["content"]) > 0
  75. def test_with_ctx_shift():
  76. global server
  77. server.n_ctx = 64
  78. server.start()
  79. res = server.make_request("POST", "/completion", data={
  80. "prompt": "Hello " * 56,
  81. "temperature": 0.0,
  82. "top_k": 1,
  83. "n_predict": 64,
  84. "speculative.p_min": 0.0,
  85. })
  86. assert res.status_code == 200
  87. assert len(res.body["content"]) > 0
  88. assert res.body["tokens_predicted"] == 64
  89. assert res.body["truncated"] == True
  90. @pytest.mark.parametrize("n_slots,n_requests", [
  91. (1, 2),
  92. (2, 2),
  93. ])
  94. def test_multi_requests_parallel(n_slots: int, n_requests: int):
  95. global server
  96. server.n_slots = n_slots
  97. server.start()
  98. tasks = []
  99. for _ in range(n_requests):
  100. tasks.append((server.make_request, ("POST", "/completion", {
  101. "prompt": "I believe the meaning of life is",
  102. "temperature": 0.0,
  103. "top_k": 1,
  104. })))
  105. results = parallel_function_calls(tasks)
  106. for res in results:
  107. assert res.status_code == 200
  108. assert match_regex("(wise|kind|owl|answer)+", res.body["content"])