import pytest
from openai import OpenAI
from utils import *
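
# ServerPreset and match_regex come from the star-imported local utils module;
# match_regex(pattern, text) is assumed to return True when the regular
# expression matches somewhere in the text.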

server = ServerPreset.tinyllama2()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate the server preset once per test module; tests call server.start() themselves."""
    global server
    server = ServerPreset.tinyllama2()
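

# The tinyllama2 preset is a small test model with effectively deterministic
# output; that is what lets the tests below assert exact prompt/completion
# token counts and regex-match the generated content.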
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"])
    if truncated:
        assert choice["finish_reason"] == "length"
    else:
        assert choice["finish_reason"] == "stop"
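

# Streaming variant of the test above. Intermediate chunks are expected to
# carry text in choice["delta"]["content"] with finish_reason still None; the
# final chunk has an empty delta and carries finish_reason plus the usage totals.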
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
    ]
)
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] in ["stop", "length"]:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, content)
            # FIXME: not sure why this is incorrect in stream mode
            # if truncated:
            #     assert choice["finish_reason"] == "length"
            # else:
            #     assert choice["finish_reason"] == "stop"
        else:
            assert choice["finish_reason"] is None
            content += choice["delta"]["content"]
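

# The server speaks the OpenAI chat-completions protocol, so the official
# openai client works unchanged once base_url points at the local server.
# api_key is a dummy value, and the model name presumably only labels the
# response, since the server serves whichever model it was started with.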
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    print(res)
    assert res.choices[0].finish_reason == "stop"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
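

# Each "schema" below is a JSON-Schema fragment that the server is expected to
# enforce on the generated text (llama.cpp compiles such schemas to a GBNF
# grammar); malformed schemas and unknown response_format types should be
# rejected with an error status, as the re_content=None cases check.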
- @pytest.mark.parametrize("response_format,n_predicted,re_content", [
- ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
- ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
- ({"type": "json_object"}, 10, "(\\{|John)+"),
- ({"type": "sound"}, 0, None),
- # invalid response format (expected to fail)
- ({"type": "json_object", "schema": 123}, 0, None),
- ({"type": "json_object", "schema": {"type": 123}}, 0, None),
- ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
- ])
- def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
- global server
- server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "max_tokens": n_predicted,
- "messages": [
- {"role": "system", "content": "You are a coding assistant."},
- {"role": "user", "content": "Write an example"},
- ],
- "response_format": response_format,
- })
- if re_content is not None:
- assert res.status_code == 200
- choice = res.body["choices"][0]
- assert match_regex(re_content, choice["message"]["content"])
- else:
- assert res.status_code != 200
- assert "error" in res.body
- @pytest.mark.parametrize("messages", [
- None,
- "string",
- [123],
- [{}],
- [{"role": 123}],
- [{"role": "system", "content": 123}],
- # [{"content": "hello"}], # TODO: should not be a valid case
- [{"role": "system", "content": "test"}, {}],
- ])
- def test_invalid_chat_completion_req(messages):
- global server
- server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "messages": messages,
- })
- assert res.status_code == 400 or res.status_code == 500
- assert "error" in res.body
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    })
    for data in res:
        assert "timings" in data
        assert "prompt_per_second" in data["timings"]
        assert "predicted_per_second" in data["timings"]
        assert "predicted_n" in data["timings"]
        assert data["timings"]["predicted_n"] <= 10