import pytest
import requests
import time
from openai import OpenAI
from utils import *

server = ServerPreset.tinyllama2()

JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()


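# Each parametrized case below specifies: the prompt, the requested n_predict, a regex the
# generated content must match, the expected prompt/predicted token counts, the expected
# truncation flag, and whether token IDs should be returned alongside the text.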
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "return_tokens": return_tokens,
    })
    assert res.status_code == 200
    assert res.body["timings"]["prompt_n"] == n_prompt
    assert res.body["timings"]["predicted_n"] == n_predicted
    assert res.body["truncated"] == truncated
    assert type(res.body["has_new_line"]) == bool
    assert match_regex(re_content, res.body["content"])
    if return_tokens:
        assert len(res.body["tokens"]) > 0
        assert all(type(tok) == int for tok in res.body["tokens"])
    else:
        assert res.body["tokens"] == []


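# Streaming variant of the test above: intermediate chunks carry token IDs and content,
# while the final chunk (stop == True) carries timings, generation settings, and stop info.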
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "stream": True,
    })
    content = ""
    for data in res:
        assert "stop" in data and type(data["stop"]) == bool
        if data["stop"]:
            assert data["timings"]["prompt_n"] == n_prompt
            assert data["timings"]["predicted_n"] == n_predicted
            assert data["truncated"] == truncated
            assert data["stop_type"] == "limit"
            assert type(data["has_new_line"]) == bool
            assert "generation_settings" in data
            assert server.n_predict is not None
            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
            assert data["generation_settings"]["seed"] == server.seed
            assert match_regex(re_content, content)
        else:
            assert len(data["tokens"]) > 0
            assert all(type(tok) == int for tok in data["tokens"])
            content += data["content"]


def test_completion_stream_vs_non_stream():
    global server
    server.start()
    res_stream = server.make_stream_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
        "stream": True,
    })
    res_non_stream = server.make_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
    })
    content_stream = ""
    for data in res_stream:
        content_stream += data["content"]
    assert content_stream == res_non_stream.body["content"]


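# The server's /v1 endpoints should be usable through the official OpenAI Python client.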
def test_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].text is not None
    assert match_regex("(going|bed)+", res.choices[0].text)


def test_completion_stream_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
        stream=True,
    )
    output_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("(going|bed)+", output_text)


# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
@pytest.mark.slow
def test_completion_stream_with_openai_library_stops():
    global server
    server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M"
    server.model_hf_file = None
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
        stop=["User:\n", "Assistant:\n"],
        max_tokens=200,
        stream=True,
    )
    output_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'


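# Determinism checks: with a fixed seed and temperature 0, the same prompt must yield
# identical content on every request; different seeds at temperature 1 must not.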
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": 0.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res


@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for seed in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": seed,
            "temperature": 1.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] != last_res.body["content"]
        last_res = res


# TODO: figure out why this does not work with temperature = 1
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
    global server
    server.n_batch = n_batch
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": temperature,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res


@pytest.mark.skip(reason="This test fails on Linux and needs to be fixed")
def test_cache_vs_nocache_prompt():
    global server
    server.start()
    res_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": True,
    })
    res_no_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res_cache.body["content"] == res_no_cache.body["content"]


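# A very long input prompt with prompt caching disabled is expected to be rejected (HTTP 400).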
def test_nocache_long_input_prompt():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is" * 32,
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res.status_code == 400


def test_json_prompt_no_mtmd():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res.status_code == 200


def test_json_prompt_mtm_error_when_not_supported():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    # MTMD is disabled on this model, so this should fail.
    assert res.status_code != 200


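# /completion also accepts pre-tokenized input: a plain token array, a batch of prompts,
# and prompts mixing strings, token IDs, and JSON prompt objects.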
def test_completion_with_tokens_input():
    global server
    server.temperature = 0.0
    server.start()
    prompt_str = "I believe the meaning of life is"
    res = server.make_request("POST", "/tokenize", data={
        "content": prompt_str,
        "add_special": True,
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]

    # single completion
    res = server.make_request("POST", "/completion", data={
        "prompt": tokens,
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str

    # batch completion
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, tokens],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed JSON and tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": [
            tokens,
            {
                JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
            },
        ],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens in one sequence
    res = server.make_request("POST", "/completion", data={
        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str


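# Run n_requests completions concurrently against n_slots slots and, while they are in
# flight, verify the expected slot occupancy via GET /slots.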
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2),  # some slots must be idle
    (4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another a poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]

    def check_slots_status():
        should_all_slots_busy = n_requests >= n_slots
        time.sleep(0.1)
        res = server.make_request("GET", "/slots")
        n_busy = sum([1 for slot in res.body if slot["is_processing"]])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # check results
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        res = results[i]
        assert res.status_code == 200
        assert type(res.body["content"]) == str
        assert len(res.body["content"]) > 10
        # FIXME: the result is not deterministic when using a slot other than slot 0
        # assert match_regex(re_content, res.body["content"])


@pytest.mark.parametrize(
    "prompt,n_predict,response_fields",
    [
        ("I believe the meaning of life is", 8, []),
        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
    ],
)
def test_completion_response_fields(
    prompt: str, n_predict: int, response_fields: list[str]
):
    global server
    server.start()
    res = server.make_request(
        "POST",
        "/completion",
        data={
            "n_predict": n_predict,
            "prompt": prompt,
            "response_fields": response_fields,
        },
    )
    assert res.status_code == 200
    assert "content" in res.body
    assert len(res.body["content"])
    if len(response_fields):
        assert res.body["generation_settings/n_predict"] == n_predict
        assert res.body["prompt"] == "<s> " + prompt
        assert isinstance(res.body["content"], str)
        assert len(res.body) == len(response_fields)
    else:
        assert len(res.body)
        assert "generation_settings" in res.body


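# With n_probs=10, every generated token should come with its log-probability and the
# 10 top candidate tokens (id, token string, logprob, bytes).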
def test_n_probs():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "logprob" in tok and tok["logprob"] <= 0.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_logprobs"]) == 10
        for prob in tok["top_logprobs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "logprob" in prob and prob["logprob"] <= 0.0
            assert "bytes" in prob and type(prob["bytes"]) == list


def test_n_probs_stream():
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "stream": True,
    })
    for data in res:
        if data["stop"] == False:
            assert "completion_probabilities" in data
            assert len(data["completion_probabilities"]) == 1
            for tok in data["completion_probabilities"]:
                assert "id" in tok and tok["id"] > 0
                assert "token" in tok and type(tok["token"]) == str
                assert "logprob" in tok and tok["logprob"] <= 0.0
                assert "bytes" in tok and type(tok["bytes"]) == list
                assert len(tok["top_logprobs"]) == 10
                for prob in tok["top_logprobs"]:
                    assert "id" in prob and prob["id"] > 0
                    assert "token" in prob and type(prob["token"]) == str
                    assert "logprob" in prob and prob["logprob"] <= 0.0
                    assert "bytes" in prob and type(prob["bytes"]) == list


def test_n_probs_post_sampling():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "post_sampling_probs": True,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_probs"]) == 10
        for prob in tok["top_probs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
            assert "bytes" in prob and type(prob["bytes"]) == list
        # because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])


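# A logit_bias of -100 should effectively ban the excluded words; both [token, bias] pairs
# and OpenAI-style {token: bias} maps are accepted, with tokens given as IDs or strings.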
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
    global server
    server.start()

    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]

    logit_bias = []
    if tokenize:
        res = server.make_request("POST", "/tokenize", data={
            "content": " " + " ".join(exclude) + " ",
        })
        assert res.status_code == 200
        tokens = res.body["tokens"]
        logit_bias = [[tok, -100] for tok in tokens]
    else:
        logit_bias = [[" " + tok + " ", -100] for tok in exclude]
    if openai_style:
        logit_bias = {el[0]: -100 for el in logit_bias}

    res = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": "What is the best book",
        "logit_bias": logit_bias,
        "temperature": 0.0,
    })
    assert res.status_code == 200
    output_text = res.body["content"]
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)


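# Cancelling a request (here via a client-side read timeout) must free its slot;
# shortly afterwards, GET /slots should report the slot as idle.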
def test_cancel_request():
    global server
    server.n_ctx = 4096
    server.n_predict = -1
    server.n_slots = 1
    server.server_slots = True
    server.start()
    # send a request that will take a long time, but cancel it before it finishes
    try:
        server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
        }, timeout=0.1)
    except requests.exceptions.ReadTimeout:
        pass  # expected
    # make sure the slot is free
    time.sleep(1)  # wait for HTTP_POLLING_SECONDS
    res = server.make_request("GET", "/slots")
    assert res.body[0]["is_processing"] == False