# test_chat_completion.py — tests for the server's /chat/completions endpoint
  1. import pytest
  2. from openai import OpenAI
  3. from utils import *
  4. server = ServerPreset.tinyllama2()
  5. @pytest.fixture(scope="module", autouse=True)
  6. def create_server():
  7. global server
  8. server = ServerPreset.tinyllama2()
  9. @pytest.mark.parametrize(
  10. "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
  11. [
  12. ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
  13. ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
  14. ]
  15. )
  16. def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
  17. global server
  18. server.start()
  19. res = server.make_request("POST", "/chat/completions", data={
  20. "model": model,
  21. "max_tokens": max_tokens,
  22. "messages": [
  23. {"role": "system", "content": system_prompt},
  24. {"role": "user", "content": user_prompt},
  25. ],
  26. })
  27. assert res.status_code == 200
  28. assert res.body["usage"]["prompt_tokens"] == n_prompt
  29. assert res.body["usage"]["completion_tokens"] == n_predicted
  30. choice = res.body["choices"][0]
  31. assert "assistant" == choice["message"]["role"]
  32. assert match_regex(re_content, choice["message"]["content"])
  33. if truncated:
  34. assert choice["finish_reason"] == "length"
  35. else:
  36. assert choice["finish_reason"] == "stop"
  37. @pytest.mark.parametrize(
  38. "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
  39. [
  40. ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
  41. ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
  42. ]
  43. )
  44. def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
  45. global server
  46. server.start()
  47. res = server.make_stream_request("POST", "/chat/completions", data={
  48. "model": model,
  49. "max_tokens": max_tokens,
  50. "messages": [
  51. {"role": "system", "content": system_prompt},
  52. {"role": "user", "content": user_prompt},
  53. ],
  54. "stream": True,
  55. })
  56. content = ""
  57. for data in res:
  58. choice = data["choices"][0]
  59. if choice["finish_reason"] in ["stop", "length"]:
  60. assert data["usage"]["prompt_tokens"] == n_prompt
  61. assert data["usage"]["completion_tokens"] == n_predicted
  62. assert "content" not in choice["delta"]
  63. assert match_regex(re_content, content)
  64. # FIXME: not sure why this is incorrect in stream mode
  65. # if truncated:
  66. # assert choice["finish_reason"] == "length"
  67. # else:
  68. # assert choice["finish_reason"] == "stop"
  69. else:
  70. assert choice["finish_reason"] is None
  71. content += choice["delta"]["content"]
  72. def test_chat_completion_with_openai_library():
  73. global server
  74. server.start()
  75. client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
  76. res = client.chat.completions.create(
  77. model="gpt-3.5-turbo-instruct",
  78. messages=[
  79. {"role": "system", "content": "Book"},
  80. {"role": "user", "content": "What is the best book"},
  81. ],
  82. max_tokens=8,
  83. seed=42,
  84. temperature=0.8,
  85. )
  86. print(res)
  87. assert res.choices[0].finish_reason == "stop"
  88. assert res.choices[0].message.content is not None
  89. assert match_regex("(Suddenly)+", res.choices[0].message.content)
  90. @pytest.mark.parametrize("response_format,n_predicted,re_content", [
  91. ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
  92. ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
  93. ({"type": "json_object"}, 10, "(\\{|John)+"),
  94. ({"type": "sound"}, 0, None),
  95. # invalid response format (expected to fail)
  96. ({"type": "json_object", "schema": 123}, 0, None),
  97. ({"type": "json_object", "schema": {"type": 123}}, 0, None),
  98. ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
  99. ])
  100. def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
  101. global server
  102. server.start()
  103. res = server.make_request("POST", "/chat/completions", data={
  104. "max_tokens": n_predicted,
  105. "messages": [
  106. {"role": "system", "content": "You are a coding assistant."},
  107. {"role": "user", "content": "Write an example"},
  108. ],
  109. "response_format": response_format,
  110. })
  111. if re_content is not None:
  112. assert res.status_code == 200
  113. choice = res.body["choices"][0]
  114. assert match_regex(re_content, choice["message"]["content"])
  115. else:
  116. assert res.status_code != 200
  117. assert "error" in res.body
  118. @pytest.mark.parametrize("messages", [
  119. None,
  120. "string",
  121. [123],
  122. [{}],
  123. [{"role": 123}],
  124. [{"role": "system", "content": 123}],
  125. # [{"content": "hello"}], # TODO: should not be a valid case
  126. [{"role": "system", "content": "test"}, {}],
  127. ])
  128. def test_invalid_chat_completion_req(messages):
  129. global server
  130. server.start()
  131. res = server.make_request("POST", "/chat/completions", data={
  132. "messages": messages,
  133. })
  134. assert res.status_code == 400 or res.status_code == 500
  135. assert "error" in res.body