# test_completion.py
  1. import pytest
  2. import time
  3. from utils import *
  4. server = ServerPreset.tinyllama2()
  5. @pytest.fixture(scope="module", autouse=True)
  6. def create_server():
  7. global server
  8. server = ServerPreset.tinyllama2()
  9. @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
  10. ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
  11. ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
  12. ])
  13. def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
  14. global server
  15. server.start()
  16. res = server.make_request("POST", "/completion", data={
  17. "n_predict": n_predict,
  18. "prompt": prompt,
  19. })
  20. assert res.status_code == 200
  21. assert res.body["timings"]["prompt_n"] == n_prompt
  22. assert res.body["timings"]["predicted_n"] == n_predicted
  23. assert res.body["truncated"] == truncated
  24. assert match_regex(re_content, res.body["content"])
  25. @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
  26. ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
  27. ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
  28. ])
  29. def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
  30. global server
  31. server.start()
  32. res = server.make_stream_request("POST", "/completion", data={
  33. "n_predict": n_predict,
  34. "prompt": prompt,
  35. "stream": True,
  36. })
  37. content = ""
  38. for data in res:
  39. if data["stop"]:
  40. assert data["timings"]["prompt_n"] == n_prompt
  41. assert data["timings"]["predicted_n"] == n_predicted
  42. assert data["truncated"] == truncated
  43. assert match_regex(re_content, content)
  44. else:
  45. content += data["content"]
  46. def test_completion_stream_vs_non_stream():
  47. global server
  48. server.start()
  49. res_stream = server.make_stream_request("POST", "/completion", data={
  50. "n_predict": 8,
  51. "prompt": "I believe the meaning of life is",
  52. "stream": True,
  53. })
  54. res_non_stream = server.make_request("POST", "/completion", data={
  55. "n_predict": 8,
  56. "prompt": "I believe the meaning of life is",
  57. })
  58. content_stream = ""
  59. for data in res_stream:
  60. content_stream += data["content"]
  61. assert content_stream == res_non_stream.body["content"]
  62. @pytest.mark.parametrize("n_slots", [1, 2])
  63. def test_consistent_result_same_seed(n_slots: int):
  64. global server
  65. server.n_slots = n_slots
  66. server.start()
  67. last_res = None
  68. for _ in range(4):
  69. res = server.make_request("POST", "/completion", data={
  70. "prompt": "I believe the meaning of life is",
  71. "seed": 42,
  72. "temperature": 1.0,
  73. "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
  74. })
  75. if last_res is not None:
  76. assert res.body["content"] == last_res.body["content"]
  77. last_res = res
  78. @pytest.mark.parametrize("n_slots", [1, 2])
  79. def test_different_result_different_seed(n_slots: int):
  80. global server
  81. server.n_slots = n_slots
  82. server.start()
  83. last_res = None
  84. for seed in range(4):
  85. res = server.make_request("POST", "/completion", data={
  86. "prompt": "I believe the meaning of life is",
  87. "seed": seed,
  88. "temperature": 1.0,
  89. "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
  90. })
  91. if last_res is not None:
  92. assert res.body["content"] != last_res.body["content"]
  93. last_res = res
  94. @pytest.mark.parametrize("n_batch", [16, 32])
  95. @pytest.mark.parametrize("temperature", [0.0, 1.0])
  96. def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
  97. global server
  98. server.n_batch = n_batch
  99. server.start()
  100. last_res = None
  101. for _ in range(4):
  102. res = server.make_request("POST", "/completion", data={
  103. "prompt": "I believe the meaning of life is",
  104. "seed": 42,
  105. "temperature": temperature,
  106. "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
  107. })
  108. if last_res is not None:
  109. assert res.body["content"] == last_res.body["content"]
  110. last_res = res
  111. @pytest.mark.skip(reason="This test fails on linux, need to be fixed")
  112. def test_cache_vs_nocache_prompt():
  113. global server
  114. server.start()
  115. res_cache = server.make_request("POST", "/completion", data={
  116. "prompt": "I believe the meaning of life is",
  117. "seed": 42,
  118. "temperature": 1.0,
  119. "cache_prompt": True,
  120. })
  121. res_no_cache = server.make_request("POST", "/completion", data={
  122. "prompt": "I believe the meaning of life is",
  123. "seed": 42,
  124. "temperature": 1.0,
  125. "cache_prompt": False,
  126. })
  127. assert res_cache.body["content"] == res_no_cache.body["content"]
  128. def test_completion_with_tokens_input():
  129. global server
  130. server.temperature = 0.0
  131. server.start()
  132. prompt_str = "I believe the meaning of life is"
  133. res = server.make_request("POST", "/tokenize", data={
  134. "content": prompt_str,
  135. "add_special": True,
  136. })
  137. assert res.status_code == 200
  138. tokens = res.body["tokens"]
  139. # single completion
  140. res = server.make_request("POST", "/completion", data={
  141. "prompt": tokens,
  142. })
  143. assert res.status_code == 200
  144. assert type(res.body["content"]) == str
  145. # batch completion
  146. res = server.make_request("POST", "/completion", data={
  147. "prompt": [tokens, tokens],
  148. })
  149. assert res.status_code == 200
  150. assert type(res.body) == list
  151. assert len(res.body) == 2
  152. assert res.body[0]["content"] == res.body[1]["content"]
  153. # mixed string and tokens
  154. res = server.make_request("POST", "/completion", data={
  155. "prompt": [tokens, prompt_str],
  156. })
  157. assert res.status_code == 200
  158. assert type(res.body) == list
  159. assert len(res.body) == 2
  160. assert res.body[0]["content"] == res.body[1]["content"]
  161. # mixed string and tokens in one sequence
  162. res = server.make_request("POST", "/completion", data={
  163. "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
  164. })
  165. assert res.status_code == 200
  166. assert type(res.body["content"]) == str
  167. @pytest.mark.parametrize("n_slots,n_requests", [
  168. (1, 3),
  169. (2, 2),
  170. (2, 4),
  171. (4, 2), # some slots must be idle
  172. (4, 6),
  173. ])
  174. def test_completion_parallel_slots(n_slots: int, n_requests: int):
  175. global server
  176. server.n_slots = n_slots
  177. server.temperature = 0.0
  178. server.start()
  179. PROMPTS = [
  180. ("Write a very long book.", "(very|special|big)+"),
  181. ("Write another a poem.", "(small|house)+"),
  182. ("What is LLM?", "(Dad|said)+"),
  183. ("The sky is blue and I love it.", "(climb|leaf)+"),
  184. ("Write another very long music lyrics.", "(friends|step|sky)+"),
  185. ("Write a very long joke.", "(cat|Whiskers)+"),
  186. ]
  187. def check_slots_status():
  188. should_all_slots_busy = n_requests >= n_slots
  189. time.sleep(0.1)
  190. res = server.make_request("GET", "/slots")
  191. n_busy = sum([1 for slot in res.body if slot["is_processing"]])
  192. if should_all_slots_busy:
  193. assert n_busy == n_slots
  194. else:
  195. assert n_busy <= n_slots
  196. tasks = []
  197. for i in range(n_requests):
  198. prompt, re_content = PROMPTS[i % len(PROMPTS)]
  199. tasks.append((server.make_request, ("POST", "/completion", {
  200. "prompt": prompt,
  201. "seed": 42,
  202. "temperature": 1.0,
  203. })))
  204. tasks.append((check_slots_status, ()))
  205. results = parallel_function_calls(tasks)
  206. # check results
  207. for i in range(n_requests):
  208. prompt, re_content = PROMPTS[i % len(PROMPTS)]
  209. res = results[i]
  210. assert res.status_code == 200
  211. assert type(res.body["content"]) == str
  212. assert len(res.body["content"]) > 10
  213. # FIXME: the result is not deterministic when using other slot than slot 0
  214. # assert match_regex(re_content, res.body["content"])
  215. def test_n_probs():
  216. global server
  217. server.start()
  218. res = server.make_request("POST", "/completion", data={
  219. "prompt": "I believe the meaning of life is",
  220. "n_probs": 10,
  221. "temperature": 0.0,
  222. "n_predict": 5,
  223. })
  224. assert res.status_code == 200
  225. assert "completion_probabilities" in res.body
  226. assert len(res.body["completion_probabilities"]) == 5
  227. for tok in res.body["completion_probabilities"]:
  228. assert "probs" in tok
  229. assert len(tok["probs"]) == 10
  230. for prob in tok["probs"]:
  231. assert "prob" in prob
  232. assert "tok_str" in prob
  233. assert 0.0 <= prob["prob"] <= 1.0