# test_completion.py

import pytest
import time
from openai import OpenAI
from utils import *

server = ServerPreset.tinyllama2()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
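

# Basic non-streaming /completion: check prompt/predicted token counts, the
# truncation flag, content against a regex, and the optional return_tokens output.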
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "return_tokens": return_tokens,
    })
    assert res.status_code == 200
    assert res.body["timings"]["prompt_n"] == n_prompt
    assert res.body["timings"]["predicted_n"] == n_predicted
    assert res.body["truncated"] == truncated
    assert type(res.body["has_new_line"]) == bool
    assert match_regex(re_content, res.body["content"])
    if return_tokens:
        assert len(res.body["tokens"]) > 0
        assert all(type(tok) == int for tok in res.body["tokens"])
    else:
        assert res.body["tokens"] == []
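

# Streaming /completion: non-final chunks carry tokens; the final "stop" chunk
# reports timings, stop_type and the effective generation_settings.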
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "stream": True,
    })
    content = ""
    for data in res:
        assert "stop" in data and type(data["stop"]) == bool
        if data["stop"]:
            assert data["timings"]["prompt_n"] == n_prompt
            assert data["timings"]["predicted_n"] == n_predicted
            assert data["truncated"] == truncated
            assert data["stop_type"] == "limit"
            assert type(data["has_new_line"]) == bool
            assert "generation_settings" in data
            assert server.n_predict is not None
            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
            assert data["generation_settings"]["seed"] == server.seed
            assert match_regex(re_content, content)
        else:
            assert len(data["tokens"]) > 0
            assert all(type(tok) == int for tok in data["tokens"])
            content += data["content"]
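

# The concatenated streamed content must match the non-streamed response.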
def test_completion_stream_vs_non_stream():
    global server
    server.start()
    res_stream = server.make_stream_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
        "stream": True,
    })
    res_non_stream = server.make_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
    })
    content_stream = ""
    for data in res_stream:
        content_stream += data["content"]
    assert content_stream == res_non_stream.body["content"]
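

# Drive the /v1 compatibility layer with the official OpenAI client (non-streaming).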
def test_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].text is not None
    assert match_regex("(going|bed)+", res.choices[0].text)
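

# Streaming via the OpenAI client: accumulate text deltas until finish_reason is set.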
def test_completion_stream_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
        stream=True,
    )
    output_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("(going|bed)+", output_text)
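

# A fixed seed at temperature 0.0 must reproduce identical content across repeats,
# regardless of the number of slots.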
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": 0.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res
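

# At temperature 1.0, different seeds should produce different content.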
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for seed in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": seed,
            "temperature": 1.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] != last_res.body["content"]
        last_res = res
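

# Deterministic output should not depend on the batch size: for each n_batch,
# repeated requests with the same seed must agree with each other.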
# TODO: figure out why this doesn't work with temperature = 1
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
    global server
    server.n_batch = n_batch
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": temperature,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res
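

# Prompt caching should not change the generated content.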
@pytest.mark.skip(reason="This test fails on linux, needs to be fixed")
def test_cache_vs_nocache_prompt():
    global server
    server.start()
    res_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": True,
    })
    res_no_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res_cache.body["content"] == res_no_cache.body["content"]
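

# /completion also accepts pre-tokenized prompts (obtained from /tokenize): a
# single token array, a batch of arrays, and mixed string/token inputs.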
def test_completion_with_tokens_input():
    global server
    server.temperature = 0.0
    server.start()
    prompt_str = "I believe the meaning of life is"
    res = server.make_request("POST", "/tokenize", data={
        "content": prompt_str,
        "add_special": True,
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]

    # single completion
    res = server.make_request("POST", "/completion", data={
        "prompt": tokens,
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str

    # batch completion
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, tokens],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens in one sequence
    res = server.make_request("POST", "/completion", data={
        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str
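

# Run n_requests completions concurrently over n_slots slots, probing /slots
# mid-flight to verify the expected number of busy slots.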
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2),  # some slots must be idle
    (4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another a poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]

    def check_slots_status():
        should_all_slots_busy = n_requests >= n_slots
        time.sleep(0.1)
        res = server.make_request("GET", "/slots")
        n_busy = sum([1 for slot in res.body if slot["is_processing"]])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # check results
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        res = results[i]
        assert res.status_code == 200
        assert type(res.body["content"]) == str
        assert len(res.body["content"]) > 10
        # FIXME: the result is not deterministic when using a slot other than slot 0
        # assert match_regex(re_content, res.body["content"])
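

# response_fields trims the response body down to the requested keys, including
# nested paths such as "generation_settings/n_predict".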
@pytest.mark.parametrize(
    "prompt,n_predict,response_fields",
    [
        ("I believe the meaning of life is", 8, []),
        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
    ],
)
def test_completion_response_fields(
    prompt: str, n_predict: int, response_fields: list[str]
):
    global server
    server.start()
    res = server.make_request(
        "POST",
        "/completion",
        data={
            "n_predict": n_predict,
            "prompt": prompt,
            "response_fields": response_fields,
        },
    )
    assert res.status_code == 200
    assert "content" in res.body
    assert len(res.body["content"])
    if len(response_fields):
        assert res.body["generation_settings/n_predict"] == n_predict
        assert res.body["prompt"] == "<s> " + prompt
        assert isinstance(res.body["content"], str)
        assert len(res.body) == len(response_fields)
    else:
        assert len(res.body)
        assert "generation_settings" in res.body
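

# n_probs: each of the 5 predicted tokens carries a logprob and its top-10 candidates.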
def test_n_probs():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "logprob" in tok and tok["logprob"] <= 0.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_logprobs"]) == 10
        for prob in tok["top_logprobs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "logprob" in prob and prob["logprob"] <= 0.0
            assert "bytes" in prob and type(prob["bytes"]) == list
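

# Same logprob payload checks, applied per streamed chunk (one token per chunk).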
def test_n_probs_stream():
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "stream": True,
    })
    for data in res:
        if not data["stop"]:
            assert "completion_probabilities" in data
            assert len(data["completion_probabilities"]) == 1
            for tok in data["completion_probabilities"]:
                assert "id" in tok and tok["id"] > 0
                assert "token" in tok and type(tok["token"]) == str
                assert "logprob" in tok and tok["logprob"] <= 0.0
                assert "bytes" in tok and type(tok["bytes"]) == list
                assert len(tok["top_logprobs"]) == 10
                for prob in tok["top_logprobs"]:
                    assert "id" in prob and prob["id"] > 0
                    assert "token" in prob and type(prob["token"]) == str
                    assert "logprob" in prob and prob["logprob"] <= 0.0
                    assert "bytes" in prob and type(prob["bytes"]) == list
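

# post_sampling_probs replaces logprobs with plain probabilities (prob / top_probs).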
def test_n_probs_post_sampling():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "post_sampling_probs": True,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_probs"]) == 10
        for prob in tok["top_probs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
            assert "bytes" in prob and type(prob["bytes"]) == list
        # the test model usually outputs tokens with either 100% or 0% probability,
        # so scan all top_probs for the near-certain one
        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
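

# Not part of the original suite: a minimal sketch of a streaming counterpart to
# test_n_probs_post_sampling. It assumes the server accepts post_sampling_probs
# together with stream=True (each option is exercised separately above) and that
# streamed chunks carry one token apiece, as in test_n_probs_stream. Treat it as
# illustrative rather than authoritative.
def test_n_probs_post_sampling_stream():
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "post_sampling_probs": True,
        "stream": True,
    })
    for data in res:
        if not data["stop"]:
            # each non-final chunk should report probabilities for its one token
            assert "completion_probabilities" in data
            assert len(data["completion_probabilities"]) == 1
            for tok in data["completion_probabilities"]:
                assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
                assert len(tok["top_probs"]) == 10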