test_speculative.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import pytest
  2. from utils import *
# The main model is an F16 MoE GGUF; the draft model is a q4_0 GGUF.
server = ServerPreset.stories15m_moe()
# Download URL of the draft model file. create_server() uses the last path
# segment of this URL as the local file name.
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
  6. def create_server():
  7. global server
  8. server = ServerPreset.stories15m_moe()
  9. # download draft model file if needed
  10. file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
  11. model_draft_file = f'../../../{file_name}'
  12. if not os.path.exists(model_draft_file):
  13. print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
  14. with open(model_draft_file, 'wb') as f:
  15. f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
  16. print(f"Done downloading draft model file")
  17. # set default values
  18. server.model_draft = model_draft_file
  19. server.draft_min = 4
  20. server.draft_max = 8
  21. @pytest.fixture(scope="module", autouse=True)
  22. def fixture_create_server():
  23. return create_server()
  24. def test_with_and_without_draft():
  25. global server
  26. server.model_draft = None # disable draft model
  27. server.start()
  28. res = server.make_request("POST", "/completion", data={
  29. "prompt": "I believe the meaning of life is",
  30. "temperature": 0.0,
  31. "top_k": 1,
  32. })
  33. assert res.status_code == 200
  34. content_no_draft = res.body["content"]
  35. server.stop()
  36. # create new server with draft model
  37. create_server()
  38. server.start()
  39. res = server.make_request("POST", "/completion", data={
  40. "prompt": "I believe the meaning of life is",
  41. "temperature": 0.0,
  42. "top_k": 1,
  43. })
  44. assert res.status_code == 200
  45. content_draft = res.body["content"]
  46. assert content_no_draft == content_draft
  47. def test_different_draft_min_draft_max():
  48. global server
  49. test_values = [
  50. (1, 2),
  51. (1, 4),
  52. (4, 8),
  53. (4, 12),
  54. (8, 16),
  55. ]
  56. last_content = None
  57. for draft_min, draft_max in test_values:
  58. server.stop()
  59. server.draft_min = draft_min
  60. server.draft_max = draft_max
  61. server.start()
  62. res = server.make_request("POST", "/completion", data={
  63. "prompt": "I believe the meaning of life is",
  64. "temperature": 0.0,
  65. "top_k": 1,
  66. })
  67. assert res.status_code == 200
  68. if last_content is not None:
  69. assert last_content == res.body["content"]
  70. last_content = res.body["content"]
  71. def test_slot_ctx_not_exceeded():
  72. global server
  73. server.n_ctx = 64
  74. server.start()
  75. res = server.make_request("POST", "/completion", data={
  76. "prompt": "Hello " * 56,
  77. "temperature": 0.0,
  78. "top_k": 1,
  79. "speculative.p_min": 0.0,
  80. })
  81. assert res.status_code == 200
  82. assert len(res.body["content"]) > 0
  83. def test_with_ctx_shift():
  84. global server
  85. server.n_ctx = 64
  86. server.start()
  87. res = server.make_request("POST", "/completion", data={
  88. "prompt": "Hello " * 56,
  89. "temperature": 0.0,
  90. "top_k": 1,
  91. "n_predict": 64,
  92. "speculative.p_min": 0.0,
  93. })
  94. assert res.status_code == 200
  95. assert len(res.body["content"]) > 0
  96. assert res.body["tokens_predicted"] == 64
  97. assert res.body["truncated"] == True
  98. @pytest.mark.parametrize("n_slots,n_requests", [
  99. (1, 2),
  100. (2, 2),
  101. ])
  102. def test_multi_requests_parallel(n_slots: int, n_requests: int):
  103. global server
  104. server.n_slots = n_slots
  105. server.start()
  106. tasks = []
  107. for _ in range(n_requests):
  108. tasks.append((server.make_request, ("POST", "/completion", {
  109. "prompt": "I believe the meaning of life is",
  110. "temperature": 0.0,
  111. "top_k": 1,
  112. })))
  113. results = parallel_function_calls(tasks)
  114. for res in results:
  115. assert res.status_code == 200
  116. assert match_regex("(wise|kind|owl|answer)+", res.body["content"])