# test_completion.py
import pytest
import time
from utils import *

# Module-level server handle shared by every test; replaced with a fresh
# preset instance by the autouse create_server fixture below.
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate the global test server once per module (autouse fixture).

    Note: this only constructs the preset; each test calls server.start()
    itself after tweaking settings (n_slots, n_batch, ...).
    """
    global server
    server = ServerPreset.tinyllama2()
  9. @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
  10. ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
  11. ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
  12. ])
  13. def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
  14. global server
  15. server.start()
  16. res = server.make_request("POST", "/completion", data={
  17. "n_predict": n_predict,
  18. "prompt": prompt,
  19. })
  20. assert res.status_code == 200
  21. assert res.body["timings"]["prompt_n"] == n_prompt
  22. assert res.body["timings"]["predicted_n"] == n_predicted
  23. assert res.body["truncated"] == truncated
  24. assert match_regex(re_content, res.body["content"])
  25. @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
  26. ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
  27. ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
  28. ])
  29. def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
  30. global server
  31. server.start()
  32. res = server.make_stream_request("POST", "/completion", data={
  33. "n_predict": n_predict,
  34. "prompt": prompt,
  35. "stream": True,
  36. })
  37. content = ""
  38. for data in res:
  39. assert "stop" in data and type(data["stop"]) == bool
  40. if data["stop"]:
  41. assert data["timings"]["prompt_n"] == n_prompt
  42. assert data["timings"]["predicted_n"] == n_predicted
  43. assert data["truncated"] == truncated
  44. assert data["stop_type"] == "limit"
  45. assert "generation_settings" in data
  46. assert server.n_predict is not None
  47. assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
  48. assert data["generation_settings"]["seed"] == server.seed
  49. assert match_regex(re_content, content)
  50. else:
  51. content += data["content"]
  52. def test_completion_stream_vs_non_stream():
  53. global server
  54. server.start()
  55. res_stream = server.make_stream_request("POST", "/completion", data={
  56. "n_predict": 8,
  57. "prompt": "I believe the meaning of life is",
  58. "stream": True,
  59. })
  60. res_non_stream = server.make_request("POST", "/completion", data={
  61. "n_predict": 8,
  62. "prompt": "I believe the meaning of life is",
  63. })
  64. content_stream = ""
  65. for data in res_stream:
  66. content_stream += data["content"]
  67. assert content_stream == res_non_stream.body["content"]
  68. @pytest.mark.parametrize("n_slots", [1, 2])
  69. def test_consistent_result_same_seed(n_slots: int):
  70. global server
  71. server.n_slots = n_slots
  72. server.start()
  73. last_res = None
  74. for _ in range(4):
  75. res = server.make_request("POST", "/completion", data={
  76. "prompt": "I believe the meaning of life is",
  77. "seed": 42,
  78. "temperature": 1.0,
  79. "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
  80. })
  81. if last_res is not None:
  82. assert res.body["content"] == last_res.body["content"]
  83. last_res = res
  84. @pytest.mark.parametrize("n_slots", [1, 2])
  85. def test_different_result_different_seed(n_slots: int):
  86. global server
  87. server.n_slots = n_slots
  88. server.start()
  89. last_res = None
  90. for seed in range(4):
  91. res = server.make_request("POST", "/completion", data={
  92. "prompt": "I believe the meaning of life is",
  93. "seed": seed,
  94. "temperature": 1.0,
  95. "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
  96. })
  97. if last_res is not None:
  98. assert res.body["content"] != last_res.body["content"]
  99. last_res = res
  100. @pytest.mark.parametrize("n_batch", [16, 32])
  101. @pytest.mark.parametrize("temperature", [0.0, 1.0])
  102. def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
  103. global server
  104. server.n_batch = n_batch
  105. server.start()
  106. last_res = None
  107. for _ in range(4):
  108. res = server.make_request("POST", "/completion", data={
  109. "prompt": "I believe the meaning of life is",
  110. "seed": 42,
  111. "temperature": temperature,
  112. "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
  113. })
  114. if last_res is not None:
  115. assert res.body["content"] == last_res.body["content"]
  116. last_res = res
  117. @pytest.mark.skip(reason="This test fails on linux, need to be fixed")
  118. def test_cache_vs_nocache_prompt():
  119. global server
  120. server.start()
  121. res_cache = server.make_request("POST", "/completion", data={
  122. "prompt": "I believe the meaning of life is",
  123. "seed": 42,
  124. "temperature": 1.0,
  125. "cache_prompt": True,
  126. })
  127. res_no_cache = server.make_request("POST", "/completion", data={
  128. "prompt": "I believe the meaning of life is",
  129. "seed": 42,
  130. "temperature": 1.0,
  131. "cache_prompt": False,
  132. })
  133. assert res_cache.body["content"] == res_no_cache.body["content"]
  134. def test_completion_with_tokens_input():
  135. global server
  136. server.temperature = 0.0
  137. server.start()
  138. prompt_str = "I believe the meaning of life is"
  139. res = server.make_request("POST", "/tokenize", data={
  140. "content": prompt_str,
  141. "add_special": True,
  142. })
  143. assert res.status_code == 200
  144. tokens = res.body["tokens"]
  145. # single completion
  146. res = server.make_request("POST", "/completion", data={
  147. "prompt": tokens,
  148. })
  149. assert res.status_code == 200
  150. assert type(res.body["content"]) == str
  151. # batch completion
  152. res = server.make_request("POST", "/completion", data={
  153. "prompt": [tokens, tokens],
  154. })
  155. assert res.status_code == 200
  156. assert type(res.body) == list
  157. assert len(res.body) == 2
  158. assert res.body[0]["content"] == res.body[1]["content"]
  159. # mixed string and tokens
  160. res = server.make_request("POST", "/completion", data={
  161. "prompt": [tokens, prompt_str],
  162. })
  163. assert res.status_code == 200
  164. assert type(res.body) == list
  165. assert len(res.body) == 2
  166. assert res.body[0]["content"] == res.body[1]["content"]
  167. # mixed string and tokens in one sequence
  168. res = server.make_request("POST", "/completion", data={
  169. "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
  170. })
  171. assert res.status_code == 200
  172. assert type(res.body["content"]) == str
  173. @pytest.mark.parametrize("n_slots,n_requests", [
  174. (1, 3),
  175. (2, 2),
  176. (2, 4),
  177. (4, 2), # some slots must be idle
  178. (4, 6),
  179. ])
  180. def test_completion_parallel_slots(n_slots: int, n_requests: int):
  181. global server
  182. server.n_slots = n_slots
  183. server.temperature = 0.0
  184. server.start()
  185. PROMPTS = [
  186. ("Write a very long book.", "(very|special|big)+"),
  187. ("Write another a poem.", "(small|house)+"),
  188. ("What is LLM?", "(Dad|said)+"),
  189. ("The sky is blue and I love it.", "(climb|leaf)+"),
  190. ("Write another very long music lyrics.", "(friends|step|sky)+"),
  191. ("Write a very long joke.", "(cat|Whiskers)+"),
  192. ]
  193. def check_slots_status():
  194. should_all_slots_busy = n_requests >= n_slots
  195. time.sleep(0.1)
  196. res = server.make_request("GET", "/slots")
  197. n_busy = sum([1 for slot in res.body if slot["is_processing"]])
  198. if should_all_slots_busy:
  199. assert n_busy == n_slots
  200. else:
  201. assert n_busy <= n_slots
  202. tasks = []
  203. for i in range(n_requests):
  204. prompt, re_content = PROMPTS[i % len(PROMPTS)]
  205. tasks.append((server.make_request, ("POST", "/completion", {
  206. "prompt": prompt,
  207. "seed": 42,
  208. "temperature": 1.0,
  209. })))
  210. tasks.append((check_slots_status, ()))
  211. results = parallel_function_calls(tasks)
  212. # check results
  213. for i in range(n_requests):
  214. prompt, re_content = PROMPTS[i % len(PROMPTS)]
  215. res = results[i]
  216. assert res.status_code == 200
  217. assert type(res.body["content"]) == str
  218. assert len(res.body["content"]) > 10
  219. # FIXME: the result is not deterministic when using other slot than slot 0
  220. # assert match_regex(re_content, res.body["content"])
  221. def test_n_probs():
  222. global server
  223. server.start()
  224. res = server.make_request("POST", "/completion", data={
  225. "prompt": "I believe the meaning of life is",
  226. "n_probs": 10,
  227. "temperature": 0.0,
  228. "n_predict": 5,
  229. })
  230. assert res.status_code == 200
  231. assert "completion_probabilities" in res.body
  232. assert len(res.body["completion_probabilities"]) == 5
  233. for tok in res.body["completion_probabilities"]:
  234. assert "probs" in tok
  235. assert len(tok["probs"]) == 10
  236. for prob in tok["probs"]:
  237. assert "prob" in prob
  238. assert "tok_str" in prob
  239. assert 0.0 <= prob["prob"] <= 1.0