import pytest
from openai import OpenAI
from utils import *

server: ServerProcess


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
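

# NOTE: the n_prompt/n_predicted token counts and the re_content patterns below
# appear to be recorded from the deterministic output of the
# ServerPreset.tinyllama2() test model (an assumption; they would need updating
# if the preset changes). re_content is a regex that match_regex() checks
# against the generated content.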
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # parenthesized so the comparison itself is asserted; without the parentheses,
    # `assert x == a if cond else b` only checks the truthiness of b when cond is False
    assert res.body["model"] == (model if model is not None else server.model_alias)
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
    assert choice["finish_reason"] == finish_reason
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                assert "role" not in choice["delta"]
            assert data["system_fingerprint"].startswith("b")
            assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
            if last_cmpl_id is None:
                last_cmpl_id = data["id"]
            assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
            if choice["finish_reason"] in ["stop", "length"]:
                assert "content" not in choice["delta"]
                assert match_regex(re_content, content)
                assert choice["finish_reason"] == finish_reason
            else:
                assert choice["finish_reason"] is None
                content += choice["delta"]["content"] or ''
        else:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)


def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
- @pytest.mark.parametrize("prefill,re_prefill", [
- ("Whill", "Whill"),
- ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
- ])
- def test_chat_template_assistant_prefill(prefill, re_prefill):
- global server
- server.chat_template = "llama3"
- server.debug = True # to get the "__verbose" object in the response
- server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "max_tokens": 8,
- "messages": [
- {"role": "system", "content": "Book"},
- {"role": "user", "content": "What is the best book"},
- {"role": "assistant", "content": prefill},
- ]
- })
- assert res.status_code == 200
- assert "__verbose" in res.body
- assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"


def test_apply_chat_template():
    global server
    server.chat_template = "command-r"
    server.start()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert res.status_code == 200
    assert "prompt" in res.body
    assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
- @pytest.mark.parametrize("response_format,n_predicted,re_content", [
- ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
- ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
- ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
- ({"type": "json_object"}, 10, "(\\{|John)+"),
- ({"type": "sound"}, 0, None),
- # invalid response format (expected to fail)
- ({"type": "json_object", "schema": 123}, 0, None),
- ({"type": "json_object", "schema": {"type": 123}}, 0, None),
- ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
- ])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200
        assert "error" in res.body
- @pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
- (False, {"const": "42"}, 6, "\"42\""),
- (True, {"const": "42"}, 6, "\"42\""),
- ])
- def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
- global server
- server.jinja = jinja
- server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "max_tokens": n_predicted,
- "messages": [
- {"role": "system", "content": "You are a coding assistant."},
- {"role": "user", "content": "Write an example"},
- ],
- "json_schema": json_schema,
- })
- assert res.status_code == 200, f'Expected 200, got {res.status_code}'
- choice = res.body["choices"][0]
- assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
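

# "grammar" takes a GBNF grammar string; root ::= "a"{5,5} constrains the output
# to exactly five 'a' characters, which the regex "a{5,5}" then matches.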
- @pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
- (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
- (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
- ])
- def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
- global server
- server.jinja = jinja
- server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "max_tokens": n_predicted,
- "messages": [
- {"role": "user", "content": "Does not matter what I say, does it?"},
- ],
- "grammar": grammar,
- })
- assert res.status_code == 200, res.body
- choice = res.body["choices"][0]
- assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
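

# Malformed "messages" payloads must be rejected; whether the server answers 400
# or 500 presumably depends on where validation fails, so both are accepted.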
- @pytest.mark.parametrize("messages", [
- None,
- "string",
- [123],
- [{}],
- [{"role": 123}],
- [{"role": "system", "content": 123}],
- # [{"content": "hello"}], # TODO: should not be a valid case
- [{"role": "system", "content": "test"}, {}],
- [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
- ])
- def test_invalid_chat_completion_req(messages):
- global server
- server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "messages": messages,
- })
- assert res.status_code == 400 or res.status_code == 500
- assert "error" in res.body
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
            assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
            assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            if data["choices"]:
                assert "role" not in data["choices"][0]["delta"]
            else:
                assert "timings" in data
                assert "prompt_per_second" in data["timings"]
                assert "predicted_per_second" in data["timings"]
                assert "predicted_n" in data["timings"]
                assert data["timings"]["predicted_n"] <= 10
                stats_received = True
    assert stats_received
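

# logprobs=True with top_logprobs requests per-token log-probabilities; the
# concatenation of the returned tokens must reproduce the message content, and
# every logprob must be <= 0.0 (i.e. probability <= 1).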
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text


def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for i, data in enumerate(res):
        if data.choices:
            choice = data.choices[0]
            if i == 0:
                # Check first role message for stream=True
                assert choice.delta.content is None
                assert choice.delta.role == "assistant"
            else:
                assert choice.delta.role is None
                if choice.finish_reason is None:
                    if choice.delta.content:
                        output_text += choice.delta.content
                    assert choice.logprobs is not None
                    assert choice.logprobs.content is not None
                    for token in choice.logprobs.content:
                        aggregated_text += token.token
                        assert token.logprob <= 0.0
                        assert token.bytes is not None
                        assert token.top_logprobs is not None
                        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
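

# A logit_bias of -100 effectively bans a token. Each word is tokenized with
# surrounding spaces first, since that is presumably how it would appear
# mid-sentence in the generated text.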
def test_logit_bias():
    global server
    server.start()
    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
    res = server.make_request("POST", "/tokenize", data={
        "content": " " + " ".join(exclude) + " ",
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]
    logit_bias = {tok: -100 for tok in tokens}
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias,
    )
    output_text = res.choices[0].message.content
    assert output_text
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
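

# Each slot gets an equal share of the context window, so the error reports
# n_ctx // n_slots as the usable context size.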
def test_context_size_exceeded():
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ] * 100,  # make the prompt too long
    })
    assert res.status_code == 400
    assert "error" in res.body
    assert res.body["error"]["type"] == "exceed_context_size_error"
    assert res.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


def test_context_size_exceeded_stream():
    global server
    server.start()
    try:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True}):
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        assert e.body["error"]["type"] == "exceed_context_size_error"
        assert e.body["error"]["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 3, False),
        (64, 1, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()

    def make_cmpl_request():
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })

    if reuse_cache:
        # make a first request to populate the cache
        res0 = make_cmpl_request()
        for _ in res0:
            pass  # discard the output

    res = make_cmpl_request()
    last_progress = None
    total_batch_count = 0
    for data in res:
        cur_progress = data.get("prompt_progress", None)
        if cur_progress is None:
            continue
        if last_progress is not None:
            assert cur_progress["total"] == last_progress["total"]
            assert cur_progress["cache"] == last_progress["cache"]
            assert cur_progress["processed"] > last_progress["processed"]
        total_batch_count += 1
        last_progress = cur_progress

    assert last_progress is not None
    assert last_progress["total"] > 0
    assert last_progress["processed"] == last_progress["total"]
    assert total_batch_count == batch_count