test_chat_completion.py

import pytest
from openai import OpenAI
from utils import *

server = ServerPreset.tinyllama2()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
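

# Basic non-streaming request. When "model" is None the server is expected to
# report its own model alias back; token counts and finish_reason are checked
# against values known for the tinyllama2 preset.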
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    assert res.body["model"] == (model if model is not None else server.model_alias)
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"])
    assert choice["finish_reason"] == finish_reason
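

# Streaming variant: with model_alias unset, the server should fall back to the
# OpenAI-compatible default model name. The final chunk (finish_reason set)
# carries the usage counts and an empty delta; every chunk in the stream must
# share a single completion id.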
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for data in res:
        choice = data["choices"][0]
        assert data["system_fingerprint"].startswith("b")
        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, may change in the future
        if last_cmpl_id is None:
            last_cmpl_id = data["id"]
        assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
        if choice["finish_reason"] in ["stop", "length"]:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, content)
            assert choice["finish_reason"] == finish_reason
        else:
            assert choice["finish_reason"] is None
            content += choice["delta"]["content"]
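

# Same endpoint driven through the official `openai` client instead of raw HTTP
# requests; the api_key is a placeholder, which the test server presumably ignores.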
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
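

# response_format handling: valid JSON schemas must constrain the output
# (checked via regex), while unknown types and malformed schemas must be
# rejected with an error response.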
  92. @pytest.mark.parametrize("response_format,n_predicted,re_content", [
  93. ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
  94. ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
  95. ({"type": "json_object"}, 10, "(\\{|John)+"),
  96. ({"type": "sound"}, 0, None),
  97. # invalid response format (expected to fail)
  98. ({"type": "json_object", "schema": 123}, 0, None),
  99. ({"type": "json_object", "schema": {"type": 123}}, 0, None),
  100. ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
  101. ])
  102. def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
  103. global server
  104. server.start()
  105. res = server.make_request("POST", "/chat/completions", data={
  106. "max_tokens": n_predicted,
  107. "messages": [
  108. {"role": "system", "content": "You are a coding assistant."},
  109. {"role": "user", "content": "Write an example"},
  110. ],
  111. "response_format": response_format,
  112. })
  113. if re_content is not None:
  114. assert res.status_code == 200
  115. choice = res.body["choices"][0]
  116. assert match_regex(re_content, choice["message"]["content"])
  117. else:
  118. assert res.status_code != 200
  119. assert "error" in res.body
  120. @pytest.mark.parametrize("messages", [
  121. None,
  122. "string",
  123. [123],
  124. [{}],
  125. [{"role": 123}],
  126. [{"role": "system", "content": 123}],
  127. # [{"content": "hello"}], # TODO: should not be a valid case
  128. [{"role": "system", "content": "test"}, {}],
  129. ])
  130. def test_invalid_chat_completion_req(messages):
  131. global server
  132. server.start()
  133. res = server.make_request("POST", "/chat/completions", data={
  134. "messages": messages,
  135. })
  136. assert res.status_code == 400 or res.status_code == 500
  137. assert "error" in res.body
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    })
    for data in res:
        assert "timings" in data
        assert "prompt_per_second" in data["timings"]
        assert "predicted_per_second" in data["timings"]
        assert "predicted_n" in data["timings"]
        assert data["timings"]["predicted_n"] <= 10
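

# Logprobs, non-streaming: concatenating the per-token entries must reproduce
# the message content exactly, and each logprob must be non-positive (a valid
# log of a probability).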
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
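

# Logprobs, streaming: same invariants as above, accumulated chunk by chunk.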
def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            if choice.delta.content:
                output_text += choice.delta.content
            assert choice.logprobs is not None
            assert choice.logprobs.content is not None
            for token in choice.logprobs.content:
                aggregated_text += token.token
                assert token.logprob <= 0.0
                assert token.bytes is not None
                assert token.top_logprobs is not None
                assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text