import pytest
from openai import OpenAI
from utils import *

server: ServerProcess
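

# autouse fixture: build a fresh tinyllama2 server preset before each test
# (each test tweaks settings and starts the server itself)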
@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
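

# non-streaming chat completion over /chat/completions: checks status, completion id
# format, system fingerprint, token counts, content regex, and finish_reason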
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
        # TODO: fix testing of non-tool jinja mode
        # (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        # (None, "Book", "What is the best book", 8, "I want to play with", 23, 8, "length", True, "This is not a chat template, it is"),
        # ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # parentheses are required: `assert a == b if cond else c` would merely assert `c` when cond is false
    assert res.body["model"] == (model if model is not None else server.model_alias)
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"])
    assert choice["finish_reason"] == finish_reason
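

# streaming variant: content arrives via deltas; the final chunk (finish_reason set)
# must carry usage stats and no further content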
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for data in res:
        choice = data["choices"][0]
        assert data["system_fingerprint"].startswith("b")
        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
        if last_cmpl_id is None:
            last_cmpl_id = data["id"]
        assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
        if choice["finish_reason"] in ["stop", "length"]:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, content)
            assert choice["finish_reason"] == finish_reason
        else:
            assert choice["finish_reason"] is None
            content += choice["delta"]["content"]
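

# same endpoint exercised through the official OpenAI Python client against the server's /v1 route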
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
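

# verify that the built-in "llama3" template renders the expected prompt,
# checked via the "__verbose" debug field in the response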
def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
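

# /apply-template only renders the prompt with the chosen template; no completion is generated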
def test_apply_chat_template():
    global server
    server.chat_template = "command-r"
    server.start()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert res.status_code == 200
    assert "prompt" in res.body
    assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
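

# response_format: valid JSON schemas must constrain the output,
# malformed ones must be rejected with an error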
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    ({"type": "sound"}, 0, None),
    # invalid response format (expected to fail)
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200
        assert "error" in res.body
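

# malformed "messages" payloads must be rejected with an error status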
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
])
def test_invalid_chat_completion_req(messages):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code == 400 or res.status_code == 500
    assert "error" in res.body
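

# with timings_per_token enabled, every stream event must carry a "timings" object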
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    })
    for data in res:
        assert "timings" in data
        assert "prompt_per_second" in data["timings"]
        assert "predicted_per_second" in data["timings"]
        assert "predicted_n" in data["timings"]
        assert data["timings"]["predicted_n"] <= 10
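

# logprobs: the per-token logprob entries must reassemble exactly into the returned message content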
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
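

# streaming logprobs: same reassembly check, applied per delta chunk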
def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            if choice.delta.content:
                output_text += choice.delta.content
            assert choice.logprobs is not None
            assert choice.logprobs.content is not None
            for token in choice.logprobs.content:
                aggregated_text += token.token
                assert token.logprob <= 0.0
                assert token.bytes is not None
                assert token.top_logprobs is not None
                assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text