# test_chat_completion.py — server chat-completion endpoint tests
  1. import pytest
  2. from openai import OpenAI
  3. from utils import *
  4. server: ServerProcess
  5. @pytest.fixture(autouse=True)
  6. def create_server():
  7. global server
  8. server = ServerPreset.tinyllama2()
  9. @pytest.mark.parametrize(
  10. "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
  11. [
  12. (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
  13. (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
  14. (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
  15. (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
  16. (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
  17. (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
  18. ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
  19. ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
  20. (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
  21. (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
  22. ]
  23. )
  24. def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
  25. global server
  26. server.jinja = jinja
  27. server.chat_template = chat_template
  28. server.start()
  29. res = server.make_request("POST", "/chat/completions", data={
  30. "model": model,
  31. "max_tokens": max_tokens,
  32. "messages": [
  33. {"role": "system", "content": system_prompt},
  34. {"role": "user", "content": user_prompt},
  35. ],
  36. })
  37. assert res.status_code == 200
  38. assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
  39. assert res.body["system_fingerprint"].startswith("b")
  40. assert res.body["model"] == model if model is not None else server.model_alias
  41. assert res.body["usage"]["prompt_tokens"] == n_prompt
  42. assert res.body["usage"]["completion_tokens"] == n_predicted
  43. choice = res.body["choices"][0]
  44. assert "assistant" == choice["message"]["role"]
  45. assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
  46. assert choice["finish_reason"] == finish_reason
  47. @pytest.mark.parametrize(
  48. "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
  49. [
  50. ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
  51. ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
  52. ]
  53. )
  54. def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
  55. global server
  56. server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
  57. server.start()
  58. res = server.make_stream_request("POST", "/chat/completions", data={
  59. "max_tokens": max_tokens,
  60. "messages": [
  61. {"role": "system", "content": system_prompt},
  62. {"role": "user", "content": user_prompt},
  63. ],
  64. "stream": True,
  65. })
  66. content = ""
  67. last_cmpl_id = None
  68. for i, data in enumerate(res):
  69. choice = data["choices"][0]
  70. if i == 0:
  71. # Check first role message for stream=True
  72. assert choice["delta"]["content"] is None
  73. assert choice["delta"]["role"] == "assistant"
  74. else:
  75. assert "role" not in choice["delta"]
  76. assert data["system_fingerprint"].startswith("b")
  77. assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
  78. if last_cmpl_id is None:
  79. last_cmpl_id = data["id"]
  80. assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
  81. if choice["finish_reason"] in ["stop", "length"]:
  82. assert data["usage"]["prompt_tokens"] == n_prompt
  83. assert data["usage"]["completion_tokens"] == n_predicted
  84. assert "content" not in choice["delta"]
  85. assert match_regex(re_content, content)
  86. assert choice["finish_reason"] == finish_reason
  87. else:
  88. assert choice["finish_reason"] is None
  89. content += choice["delta"]["content"] or ''
  90. def test_chat_completion_with_openai_library():
  91. global server
  92. server.start()
  93. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  94. res = client.chat.completions.create(
  95. model="gpt-3.5-turbo-instruct",
  96. messages=[
  97. {"role": "system", "content": "Book"},
  98. {"role": "user", "content": "What is the best book"},
  99. ],
  100. max_tokens=8,
  101. seed=42,
  102. temperature=0.8,
  103. )
  104. assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
  105. assert res.choices[0].finish_reason == "length"
  106. assert res.choices[0].message.content is not None
  107. assert match_regex("(Suddenly)+", res.choices[0].message.content)
  108. def test_chat_template():
  109. global server
  110. server.chat_template = "llama3"
  111. server.debug = True # to get the "__verbose" object in the response
  112. server.start()
  113. res = server.make_request("POST", "/chat/completions", data={
  114. "max_tokens": 8,
  115. "messages": [
  116. {"role": "system", "content": "Book"},
  117. {"role": "user", "content": "What is the best book"},
  118. ]
  119. })
  120. assert res.status_code == 200
  121. assert "__verbose" in res.body
  122. assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
  123. @pytest.mark.parametrize("prefill,re_prefill", [
  124. ("Whill", "Whill"),
  125. ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
  126. ])
  127. def test_chat_template_assistant_prefill(prefill, re_prefill):
  128. global server
  129. server.chat_template = "llama3"
  130. server.debug = True # to get the "__verbose" object in the response
  131. server.start()
  132. res = server.make_request("POST", "/chat/completions", data={
  133. "max_tokens": 8,
  134. "messages": [
  135. {"role": "system", "content": "Book"},
  136. {"role": "user", "content": "What is the best book"},
  137. {"role": "assistant", "content": prefill},
  138. ]
  139. })
  140. assert res.status_code == 200
  141. assert "__verbose" in res.body
  142. assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
  143. def test_apply_chat_template():
  144. global server
  145. server.chat_template = "command-r"
  146. server.start()
  147. res = server.make_request("POST", "/apply-template", data={
  148. "messages": [
  149. {"role": "system", "content": "You are a test."},
  150. {"role": "user", "content":"Hi there"},
  151. ]
  152. })
  153. assert res.status_code == 200
  154. assert "prompt" in res.body
  155. assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
  156. @pytest.mark.parametrize("response_format,n_predicted,re_content", [
  157. ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
  158. ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
  159. ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
  160. ({"type": "json_object"}, 10, "(\\{|John)+"),
  161. ({"type": "sound"}, 0, None),
  162. # invalid response format (expected to fail)
  163. ({"type": "json_object", "schema": 123}, 0, None),
  164. ({"type": "json_object", "schema": {"type": 123}}, 0, None),
  165. ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
  166. ])
  167. def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
  168. global server
  169. server.start()
  170. res = server.make_request("POST", "/chat/completions", data={
  171. "max_tokens": n_predicted,
  172. "messages": [
  173. {"role": "system", "content": "You are a coding assistant."},
  174. {"role": "user", "content": "Write an example"},
  175. ],
  176. "response_format": response_format,
  177. })
  178. if re_content is not None:
  179. assert res.status_code == 200
  180. choice = res.body["choices"][0]
  181. assert match_regex(re_content, choice["message"]["content"])
  182. else:
  183. assert res.status_code != 200
  184. assert "error" in res.body
  185. @pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
  186. (False, {"const": "42"}, 6, "\"42\""),
  187. (True, {"const": "42"}, 6, "\"42\""),
  188. ])
  189. def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
  190. global server
  191. server.jinja = jinja
  192. server.start()
  193. res = server.make_request("POST", "/chat/completions", data={
  194. "max_tokens": n_predicted,
  195. "messages": [
  196. {"role": "system", "content": "You are a coding assistant."},
  197. {"role": "user", "content": "Write an example"},
  198. ],
  199. "json_schema": json_schema,
  200. })
  201. assert res.status_code == 200, f'Expected 200, got {res.status_code}'
  202. choice = res.body["choices"][0]
  203. assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
  204. @pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
  205. (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
  206. (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
  207. ])
  208. def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
  209. global server
  210. server.jinja = jinja
  211. server.start()
  212. res = server.make_request("POST", "/chat/completions", data={
  213. "max_tokens": n_predicted,
  214. "messages": [
  215. {"role": "user", "content": "Does not matter what I say, does it?"},
  216. ],
  217. "grammar": grammar,
  218. })
  219. assert res.status_code == 200, res.body
  220. choice = res.body["choices"][0]
  221. assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
  222. @pytest.mark.parametrize("messages", [
  223. None,
  224. "string",
  225. [123],
  226. [{}],
  227. [{"role": 123}],
  228. [{"role": "system", "content": 123}],
  229. # [{"content": "hello"}], # TODO: should not be a valid case
  230. [{"role": "system", "content": "test"}, {}],
  231. [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
  232. ])
  233. def test_invalid_chat_completion_req(messages):
  234. global server
  235. server.start()
  236. res = server.make_request("POST", "/chat/completions", data={
  237. "messages": messages,
  238. })
  239. assert res.status_code == 400 or res.status_code == 500
  240. assert "error" in res.body
  241. def test_chat_completion_with_timings_per_token():
  242. global server
  243. server.start()
  244. res = server.make_stream_request("POST", "/chat/completions", data={
  245. "max_tokens": 10,
  246. "messages": [{"role": "user", "content": "test"}],
  247. "stream": True,
  248. "timings_per_token": True,
  249. })
  250. for i, data in enumerate(res):
  251. if i == 0:
  252. # Check first role message for stream=True
  253. assert data["choices"][0]["delta"]["content"] is None
  254. assert data["choices"][0]["delta"]["role"] == "assistant"
  255. assert "timings" not in data, f'First event should not have timings: {data}'
  256. else:
  257. assert "role" not in data["choices"][0]["delta"]
  258. assert "timings" in data
  259. assert "prompt_per_second" in data["timings"]
  260. assert "predicted_per_second" in data["timings"]
  261. assert "predicted_n" in data["timings"]
  262. assert data["timings"]["predicted_n"] <= 10
  263. def test_logprobs():
  264. global server
  265. server.start()
  266. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  267. res = client.chat.completions.create(
  268. model="gpt-3.5-turbo-instruct",
  269. temperature=0.0,
  270. messages=[
  271. {"role": "system", "content": "Book"},
  272. {"role": "user", "content": "What is the best book"},
  273. ],
  274. max_tokens=5,
  275. logprobs=True,
  276. top_logprobs=10,
  277. )
  278. output_text = res.choices[0].message.content
  279. aggregated_text = ''
  280. assert res.choices[0].logprobs is not None
  281. assert res.choices[0].logprobs.content is not None
  282. for token in res.choices[0].logprobs.content:
  283. aggregated_text += token.token
  284. assert token.logprob <= 0.0
  285. assert token.bytes is not None
  286. assert len(token.top_logprobs) > 0
  287. assert aggregated_text == output_text
  288. def test_logprobs_stream():
  289. global server
  290. server.start()
  291. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  292. res = client.chat.completions.create(
  293. model="gpt-3.5-turbo-instruct",
  294. temperature=0.0,
  295. messages=[
  296. {"role": "system", "content": "Book"},
  297. {"role": "user", "content": "What is the best book"},
  298. ],
  299. max_tokens=5,
  300. logprobs=True,
  301. top_logprobs=10,
  302. stream=True,
  303. )
  304. output_text = ''
  305. aggregated_text = ''
  306. for i, data in enumerate(res):
  307. choice = data.choices[0]
  308. if i == 0:
  309. # Check first role message for stream=True
  310. assert choice.delta.content is None
  311. assert choice.delta.role == "assistant"
  312. else:
  313. assert choice.delta.role is None
  314. if choice.finish_reason is None:
  315. if choice.delta.content:
  316. output_text += choice.delta.content
  317. assert choice.logprobs is not None
  318. assert choice.logprobs.content is not None
  319. for token in choice.logprobs.content:
  320. aggregated_text += token.token
  321. assert token.logprob <= 0.0
  322. assert token.bytes is not None
  323. assert token.top_logprobs is not None
  324. assert len(token.top_logprobs) > 0
  325. assert aggregated_text == output_text