test_chat_completion.py

import pytest
from openai import OpenAI
from utils import *

server: ServerProcess


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
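

# Non-streaming chat completion. Each case pins the exact prompt/completion
# token counts and a regex the generated content is expected to match, with
# and without the jinja engine and with chat template overrides.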
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    assert res.body["model"] == (model if model is not None else server.model_alias)
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
    assert choice["finish_reason"] == finish_reason
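

# Streaming variant. model_alias is cleared so the server falls back to
# DEFAULT_OAICOMPAT_MODEL; every event must reuse the same completion id, and
# the final usage-only event reports the exact token counts.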
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                assert "role" not in choice["delta"]
                assert data["system_fingerprint"].startswith("b")
                assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
                if last_cmpl_id is None:
                    last_cmpl_id = data["id"]
                assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
                if choice["finish_reason"] in ["stop", "length"]:
                    assert "content" not in choice["delta"]
                    assert match_regex(re_content, content)
                    assert choice["finish_reason"] == finish_reason
                else:
                    assert choice["finish_reason"] is None
                    content += choice["delta"]["content"] or ''
        else:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
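

# Same endpoint driven through the official OpenAI client, pointed at the
# server's OpenAI-compatible /v1 base URL.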
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
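

# With server.debug enabled the response includes a "__verbose" object, so the
# prompt rendered by the llama3 template can be checked verbatim.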
def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
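

# A trailing assistant message acts as a prefill: it must end up at the end of
# the rendered prompt (with no closing <|eot_id|>), whether given as a plain
# string or as a list of text parts.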
@pytest.mark.parametrize("prefill,re_prefill", [
    ("Whill", "Whill"),
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            {"role": "assistant", "content": prefill},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
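

# /apply-template only renders the chat template over the messages and returns
# the resulting prompt; no completion is generated.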
def test_apply_chat_template():
    global server
    server.chat_template = "command-r"
    server.start()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert res.status_code == 200
    assert "prompt" in res.body
    assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
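

# OpenAI-style response_format: valid json_object/json_schema constraints must
# shape the output, while malformed schemas and unknown types must be rejected.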
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    ({"type": "sound"}, 0, None),
    # invalid response format (expected to fail)
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200
        assert "error" in res.body
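

# The same constraint via the top-level "json_schema" request field, with and
# without the jinja engine.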
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
    (False, {"const": "42"}, 6, "\"42\""),
    (True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "json_schema": json_schema,
    })
    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
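

# A raw GBNF grammar limits output to exactly five "a" characters; re_content
# is a regex, so "a{5,5}" matches the literal "aaaaa".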
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "user", "content": "Does not matter what I say, does it?"},
        ],
        "grammar": grammar,
    })
    assert res.status_code == 200, res.body
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
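

# Malformed "messages" payloads must be rejected with an error response rather
# than crashing the server.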
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code == 400 or res.status_code == 500
    assert "error" in res.body
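

# With timings_per_token, the first event must not carry timings, and the
# final usage-only event must report per-second rates and a predicted token
# count bounded by max_tokens.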
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
            assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
            assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            if data["choices"]:
                assert "role" not in data["choices"][0]["delta"]
            else:
                assert "timings" in data
                assert "prompt_per_second" in data["timings"]
                assert "predicted_per_second" in data["timings"]
                assert "predicted_n" in data["timings"]
                assert data["timings"]["predicted_n"] <= 10
                stats_received = True
    assert stats_received
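

# Non-streaming logprobs: every reported token must carry a non-positive
# logprob, its raw bytes, and top alternatives, and the concatenation of the
# tokens must reproduce the message content exactly.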
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
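

# The same logprob invariants, accumulated chunk by chunk over a stream.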
def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for i, data in enumerate(res):
        if data.choices:
            choice = data.choices[0]
            if i == 0:
                # Check first role message for stream=True
                assert choice.delta.content is None
                assert choice.delta.role == "assistant"
            else:
                assert choice.delta.role is None
                if choice.finish_reason is None:
                    if choice.delta.content:
                        output_text += choice.delta.content
                    assert choice.logprobs is not None
                    assert choice.logprobs.content is not None
                    for token in choice.logprobs.content:
                        aggregated_text += token.token
                        assert token.logprob <= 0.0
                        assert token.bytes is not None
                        assert token.top_logprobs is not None
                        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
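

# logit_bias: tokenize a list of common words (with surrounding spaces, which
# presumably yields the space-prefixed token forms), bias every token to -100,
# and verify that none of the words appear in the output.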
def test_logit_bias():
    global server
    server.start()
    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
    res = server.make_request("POST", "/tokenize", data={
        "content": " " + " ".join(exclude) + " ",
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]
    logit_bias = {tok: -100 for tok in tokens}
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias,
    )
    output_text = res.choices[0].message.content
    assert output_text
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
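

# An overlong prompt must yield a structured exceed_context_size_error that
# reports the per-slot context size (n_ctx // n_slots).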
def test_context_size_exceeded():
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ] * 100,  # make the prompt too long
    })
    assert res.status_code == 400
    assert "error" in res.body
    assert res.body["error"]["type"] == "exceed_context_size_error"
    assert res.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
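

# Same check over a streaming request, where the failure surfaces as a
# ServerError raised while iterating the stream.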
def test_context_size_exceeded_stream():
    global server
    server.start()
    try:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True}):
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        assert e.body["error"]["type"] == "exceed_context_size_error"
        assert e.body["error"]["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
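

# return_progress: prompt processing is reported through per-batch
# "prompt_progress" events. "processed" must grow monotonically up to "total",
# and the event count must match the expected number of batches; with a warm
# cache the remaining prompt fits in a single batch.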
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 3, False),
        (64, 1, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()

    def make_cmpl_request():
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })

    if reuse_cache:
        # make a first request to populate the cache
        res0 = make_cmpl_request()
        for _ in res0:
            pass  # discard the output
    res = make_cmpl_request()
    last_progress = None
    total_batch_count = 0
    for data in res:
        cur_progress = data.get("prompt_progress", None)
        if cur_progress is None:
            continue
        if last_progress is not None:
            assert cur_progress["total"] == last_progress["total"]
            assert cur_progress["cache"] == last_progress["cache"]
            assert cur_progress["processed"] > last_progress["processed"]
        total_batch_count += 1
        last_progress = cur_progress
    assert last_progress is not None
    assert last_progress["total"] > 0
    assert last_progress["processed"] == last_progress["total"]
    assert total_batch_count == batch_count