# test_chat_completion.py
  1. import pytest
  2. from openai import OpenAI
  3. from utils import *
  4. server: ServerProcess
  5. @pytest.fixture(autouse=True)
  6. def create_server():
  7. global server
  8. server = ServerPreset.tinyllama2()
  9. @pytest.mark.parametrize(
  10. "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
  11. [
  12. (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
  13. (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
  14. (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
  15. (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
  16. (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
  17. (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
  18. ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
  19. ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
  20. ]
  21. )
  22. def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
  23. global server
  24. server.jinja = jinja
  25. server.chat_template = chat_template
  26. server.start()
  27. res = server.make_request("POST", "/chat/completions", data={
  28. "model": model,
  29. "max_tokens": max_tokens,
  30. "messages": [
  31. {"role": "system", "content": system_prompt},
  32. {"role": "user", "content": user_prompt},
  33. ],
  34. })
  35. assert res.status_code == 200
  36. assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
  37. assert res.body["system_fingerprint"].startswith("b")
  38. assert res.body["model"] == model if model is not None else server.model_alias
  39. assert res.body["usage"]["prompt_tokens"] == n_prompt
  40. assert res.body["usage"]["completion_tokens"] == n_predicted
  41. choice = res.body["choices"][0]
  42. assert "assistant" == choice["message"]["role"]
  43. assert match_regex(re_content, choice["message"]["content"])
  44. assert choice["finish_reason"] == finish_reason
  45. @pytest.mark.parametrize(
  46. "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
  47. [
  48. ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
  49. ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
  50. ]
  51. )
  52. def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
  53. global server
  54. server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
  55. server.start()
  56. res = server.make_stream_request("POST", "/chat/completions", data={
  57. "max_tokens": max_tokens,
  58. "messages": [
  59. {"role": "system", "content": system_prompt},
  60. {"role": "user", "content": user_prompt},
  61. ],
  62. "stream": True,
  63. })
  64. content = ""
  65. last_cmpl_id = None
  66. for data in res:
  67. choice = data["choices"][0]
  68. assert data["system_fingerprint"].startswith("b")
  69. assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
  70. if last_cmpl_id is None:
  71. last_cmpl_id = data["id"]
  72. assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
  73. if choice["finish_reason"] in ["stop", "length"]:
  74. assert data["usage"]["prompt_tokens"] == n_prompt
  75. assert data["usage"]["completion_tokens"] == n_predicted
  76. assert "content" not in choice["delta"]
  77. assert match_regex(re_content, content)
  78. assert choice["finish_reason"] == finish_reason
  79. else:
  80. assert choice["finish_reason"] is None
  81. content += choice["delta"]["content"]
  82. def test_chat_completion_with_openai_library():
  83. global server
  84. server.start()
  85. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  86. res = client.chat.completions.create(
  87. model="gpt-3.5-turbo-instruct",
  88. messages=[
  89. {"role": "system", "content": "Book"},
  90. {"role": "user", "content": "What is the best book"},
  91. ],
  92. max_tokens=8,
  93. seed=42,
  94. temperature=0.8,
  95. )
  96. assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
  97. assert res.choices[0].finish_reason == "length"
  98. assert res.choices[0].message.content is not None
  99. assert match_regex("(Suddenly)+", res.choices[0].message.content)
  100. def test_chat_template():
  101. global server
  102. server.chat_template = "llama3"
  103. server.debug = True # to get the "__verbose" object in the response
  104. server.start()
  105. res = server.make_request("POST", "/chat/completions", data={
  106. "max_tokens": 8,
  107. "messages": [
  108. {"role": "system", "content": "Book"},
  109. {"role": "user", "content": "What is the best book"},
  110. ]
  111. })
  112. assert res.status_code == 200
  113. assert "__verbose" in res.body
  114. assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
  115. def test_apply_chat_template():
  116. global server
  117. server.chat_template = "command-r"
  118. server.start()
  119. res = server.make_request("POST", "/apply-template", data={
  120. "messages": [
  121. {"role": "system", "content": "You are a test."},
  122. {"role": "user", "content":"Hi there"},
  123. ]
  124. })
  125. assert res.status_code == 200
  126. assert "prompt" in res.body
  127. assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
  128. @pytest.mark.parametrize("response_format,n_predicted,re_content", [
  129. ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
  130. ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
  131. ({"type": "json_object"}, 10, "(\\{|John)+"),
  132. ({"type": "sound"}, 0, None),
  133. # invalid response format (expected to fail)
  134. ({"type": "json_object", "schema": 123}, 0, None),
  135. ({"type": "json_object", "schema": {"type": 123}}, 0, None),
  136. ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
  137. ])
  138. def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
  139. global server
  140. server.start()
  141. res = server.make_request("POST", "/chat/completions", data={
  142. "max_tokens": n_predicted,
  143. "messages": [
  144. {"role": "system", "content": "You are a coding assistant."},
  145. {"role": "user", "content": "Write an example"},
  146. ],
  147. "response_format": response_format,
  148. })
  149. if re_content is not None:
  150. assert res.status_code == 200
  151. choice = res.body["choices"][0]
  152. assert match_regex(re_content, choice["message"]["content"])
  153. else:
  154. assert res.status_code != 200
  155. assert "error" in res.body
  156. @pytest.mark.parametrize("messages", [
  157. None,
  158. "string",
  159. [123],
  160. [{}],
  161. [{"role": 123}],
  162. [{"role": "system", "content": 123}],
  163. # [{"content": "hello"}], # TODO: should not be a valid case
  164. [{"role": "system", "content": "test"}, {}],
  165. ])
  166. def test_invalid_chat_completion_req(messages):
  167. global server
  168. server.start()
  169. res = server.make_request("POST", "/chat/completions", data={
  170. "messages": messages,
  171. })
  172. assert res.status_code == 400 or res.status_code == 500
  173. assert "error" in res.body
  174. def test_chat_completion_with_timings_per_token():
  175. global server
  176. server.start()
  177. res = server.make_stream_request("POST", "/chat/completions", data={
  178. "max_tokens": 10,
  179. "messages": [{"role": "user", "content": "test"}],
  180. "stream": True,
  181. "timings_per_token": True,
  182. })
  183. for data in res:
  184. assert "timings" in data
  185. assert "prompt_per_second" in data["timings"]
  186. assert "predicted_per_second" in data["timings"]
  187. assert "predicted_n" in data["timings"]
  188. assert data["timings"]["predicted_n"] <= 10
  189. def test_logprobs():
  190. global server
  191. server.start()
  192. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  193. res = client.chat.completions.create(
  194. model="gpt-3.5-turbo-instruct",
  195. temperature=0.0,
  196. messages=[
  197. {"role": "system", "content": "Book"},
  198. {"role": "user", "content": "What is the best book"},
  199. ],
  200. max_tokens=5,
  201. logprobs=True,
  202. top_logprobs=10,
  203. )
  204. output_text = res.choices[0].message.content
  205. aggregated_text = ''
  206. assert res.choices[0].logprobs is not None
  207. assert res.choices[0].logprobs.content is not None
  208. for token in res.choices[0].logprobs.content:
  209. aggregated_text += token.token
  210. assert token.logprob <= 0.0
  211. assert token.bytes is not None
  212. assert len(token.top_logprobs) > 0
  213. assert aggregated_text == output_text
  214. def test_logprobs_stream():
  215. global server
  216. server.start()
  217. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
  218. res = client.chat.completions.create(
  219. model="gpt-3.5-turbo-instruct",
  220. temperature=0.0,
  221. messages=[
  222. {"role": "system", "content": "Book"},
  223. {"role": "user", "content": "What is the best book"},
  224. ],
  225. max_tokens=5,
  226. logprobs=True,
  227. top_logprobs=10,
  228. stream=True,
  229. )
  230. output_text = ''
  231. aggregated_text = ''
  232. for data in res:
  233. choice = data.choices[0]
  234. if choice.finish_reason is None:
  235. if choice.delta.content:
  236. output_text += choice.delta.content
  237. assert choice.logprobs is not None
  238. assert choice.logprobs.content is not None
  239. for token in choice.logprobs.content:
  240. aggregated_text += token.token
  241. assert token.logprob <= 0.0
  242. assert token.bytes is not None
  243. assert token.top_logprobs is not None
  244. assert len(token.top_logprobs) > 0
  245. assert aggregated_text == output_text