# test_speculative.py
import pytest
from utils import *

# We use a F16 MOE gguf as main model, and q4_0 as draft model
server = ServerPreset.stories15m_moe()

MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"


def create_server():
    global server
    server = ServerPreset.stories15m_moe()
    # set default values
    server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
    server.draft_min = 4
    server.draft_max = 8
    server.fa = "off"
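
# Re-create the server with the default draft configuration before every test,
# so per-test overrides (e.g. disabling the draft model) do not leak into the
# following test.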
@pytest.fixture(autouse=True)
def fixture_create_server():
    return create_server()
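
# Speculative decoding should be lossless: with greedy sampling (temperature
# 0.0, top_k 1) every drafted token is verified against the main model, so the
# completion must be identical with and without the draft model attached.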
def test_with_and_without_draft():
    global server
    server.model_draft = None  # disable draft model
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "temperature": 0.0,
        "top_k": 1,
    })
    assert res.status_code == 200
    content_no_draft = res.body["content"]
    server.stop()

    # create new server with draft model
    create_server()
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "temperature": 0.0,
        "top_k": 1,
    })
    assert res.status_code == 200
    content_draft = res.body["content"]

    assert content_no_draft == content_draft
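
# draft_min/draft_max only control how many tokens are speculated per step,
# not which tokens end up accepted, so every window size must yield the same
# greedy completion.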
def test_different_draft_min_draft_max():
    global server
    test_values = [
        (1, 2),
        (1, 4),
        (4, 8),
        (4, 12),
        (8, 16),
    ]
    last_content = None
    for draft_min, draft_max in test_values:
        server.stop()
        server.draft_min = draft_min
        server.draft_max = draft_max
        server.start()
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "temperature": 0.0,
            "top_k": 1,
        })
        assert res.status_code == 200
        if last_content is not None:
            assert last_content == res.body["content"]
        last_content = res.body["content"]
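
# A prompt that nearly fills the 64-token context must not overflow the slot
# once drafted tokens are appended; "speculative.p_min": 0.0 keeps drafting
# active regardless of the draft model's confidence.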
def test_slot_ctx_not_exceeded():
    global server
    server.n_ctx = 64
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "Hello " * 56,
        "temperature": 0.0,
        "top_k": 1,
        "speculative.p_min": 0.0,
    })
    assert res.status_code == 200
    assert len(res.body["content"]) > 0
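
# With context shifting enabled, generation can continue past the 64-token
# context: the server should produce exactly n_predict tokens and flag the
# response as truncated.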
def test_with_ctx_shift():
    global server
    server.n_ctx = 64
    server.enable_ctx_shift = True
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "Hello " * 56,
        "temperature": 0.0,
        "top_k": 1,
        "n_predict": 64,
        "speculative.p_min": 0.0,
    })
    assert res.status_code == 200
    assert len(res.body["content"]) > 0
    assert res.body["tokens_predicted"] == 64
    assert res.body["truncated"] is True
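
# Speculative decoding must also hold up under concurrent requests, both when
# requests queue on a single slot and when each request gets its own slot.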
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 2),
    (2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.start()
    tasks = []
    for _ in range(n_requests):
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": "I believe the meaning of life is",
            "temperature": 0.0,
            "top_k": 1,
        })))
    results = parallel_function_calls(tasks)
    for res in results:
        assert res.status_code == 200
        assert match_regex("(wise|kind|owl|answer)+", res.body["content"])
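
# To run just this file (standard pytest invocation; adjust the path to match
# where this file lives in your checkout):
#   pytest test_speculative.py -v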