import pytest
from openai import OpenAI
from utils import *

server: ServerProcess


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
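

# NOTE: because the fixture above is autouse, every test below starts from a
# fresh ServerProcess preset; individual tests tweak its fields (jinja,
# chat_template, n_batch, ...) and then call server.start() themselves.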
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
    global server
    server.jinja = jinja
    server.chat_template = chat_template
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
    # assert res.body["model"] == model if model is not None else server.model_alias
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
    assert choice["finish_reason"] == finish_reason
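

# Streamed responses arrive as a sequence of chunks. Per the assertions below:
# the first chunk carries the role with null content, later chunks carry
# content deltas, and a trailing chunk with empty "choices" carries usage.
# A typical content chunk looks roughly like this (illustrative, abridged):
#   {"id": "chatcmpl-...", "model": "llama-test-model",
#    "system_fingerprint": "b...",
#    "choices": [{"delta": {"content": "..."}, "finish_reason": null}]}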
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = "llama-test-model"
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for i, data in enumerate(res):
        if data["choices"]:
            choice = data["choices"][0]
            if i == 0:
                # Check first role message for stream=True
                assert choice["delta"]["content"] is None
                assert choice["delta"]["role"] == "assistant"
            else:
                assert "role" not in choice["delta"]
            assert data["system_fingerprint"].startswith("b")
            assert data["model"] == "llama-test-model"
            if last_cmpl_id is None:
                last_cmpl_id = data["id"]
            assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
            if choice["finish_reason"] in ["stop", "length"]:
                assert "content" not in choice["delta"]
                assert match_regex(re_content, content)
                assert choice["finish_reason"] == finish_reason
            else:
                assert choice["finish_reason"] is None
                content += choice["delta"]["content"] or ''
        else:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
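

# The server speaks the OpenAI wire format under /v1, so the official openai
# client works as-is; the api_key is a dummy value and the model name in the
# request does not select anything (see the PR referenced in
# test_chat_completion above).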
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
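

# With server.debug enabled, the response carries a "__verbose" object that
# echoes the fully rendered prompt, which lets us assert on the exact output
# of the chat template.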
def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
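

# A trailing assistant message is treated as a prefill: its content is placed
# after the assistant header and generation continues from there, so the
# rendered prompt must end with the prefill text.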
@pytest.mark.parametrize("prefill,re_prefill", [
    ("Whill", "Whill"),
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            {"role": "assistant", "content": prefill},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
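

# /apply-template renders the chat template over the given messages and
# returns the resulting prompt string, without running any generation.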
def test_apply_chat_template():
    global server
    server.chat_template = "command-r"
    server.start()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "system", "content": "You are a test."},
            {"role": "user", "content": "Hi there"},
        ]
    })
    assert res.status_code == 200
    assert "prompt" in res.body
    assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
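

# "response_format" follows the OpenAI shape; accepting a "schema" inside
# "type": "json_object" is (to my knowledge) a llama.cpp extension that turns
# the schema into a sampling grammar. Unknown types and malformed schemas are
# expected to be rejected with HTTP 400.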
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    ({"type": "sound"}, 0, None),
    # invalid response format (expected to fail)
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code == 400
        assert "error" in res.body
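

# A top-level "json_schema" field appears to be a shortcut for constraining
# output to a JSON schema without going through "response_format"; it is
# checked both with and without the jinja template path.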
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
    (False, {"const": "42"}, 6, "\"42\""),
    (True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "json_schema": json_schema,
    })
    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
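

# "grammar" accepts a GBNF grammar string; 'root ::= "a"{5,5}' constrains the
# output to exactly five "a" characters regardless of what the prompt says.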
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
    global server
    server.jinja = jinja
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "user", "content": "Does not matter what I say, does it?"},
        ],
        "grammar": grammar,
    })
    assert res.status_code == 200, res.body
    choice = res.body["choices"][0]
    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
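

# Malformed "messages" payloads must be rejected; depending on where
# validation fails the server may answer 400 or 500, but always with an
# "error" body.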
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
    [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code in (400, 500)
    assert "error" in res.body
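

# With "timings_per_token", streamed events carry a "timings" object with
# per-token throughput stats; the very first (role) event must not have one.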
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "timings_per_token": True,
    })
    stats_received = False
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
            assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
            assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            if data["choices"]:
                assert "role" not in data["choices"][0]["delta"]
            else:
                assert "timings" in data
                assert "prompt_per_second" in data["timings"]
                assert "predicted_per_second" in data["timings"]
                assert "predicted_n" in data["timings"]
                assert data["timings"]["predicted_n"] <= 10
                stats_received = True
    assert stats_received
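

# With logprobs enabled, each returned token comes with its log-probability,
# raw bytes, and a top_logprobs list; concatenating the token strings must
# reproduce the message content exactly. The same invariant is checked for
# the streaming variant below.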
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text


def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for i, data in enumerate(res):
        if data.choices:
            choice = data.choices[0]
            if i == 0:
                # Check first role message for stream=True
                assert choice.delta.content is None
                assert choice.delta.role == "assistant"
            else:
                assert choice.delta.role is None
                if choice.finish_reason is None:
                    if choice.delta.content:
                        output_text += choice.delta.content
                    assert choice.logprobs is not None
                    assert choice.logprobs.content is not None
                    for token in choice.logprobs.content:
                        aggregated_text += token.token
                        assert token.logprob <= 0.0
                        assert token.bytes is not None
                        assert token.top_logprobs is not None
                        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
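

# logit_bias takes a map of token id -> bias, where -100 effectively bans a
# token. The word list is run through /tokenize first so the bias is applied
# to the exact token ids the model would emit for those words.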
def test_logit_bias():
    global server
    server.start()
    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
    res = server.make_request("POST", "/tokenize", data={
        "content": " " + " ".join(exclude) + " ",
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]
    logit_bias = {tok: -100 for tok in tokens}
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=64,
        logit_bias=logit_bias
    )
    output_text = res.choices[0].message.content
    assert output_text
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
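

# Prompts longer than the per-slot context window are rejected up front with
# a structured "exceed_context_size_error"; note that n_ctx is divided evenly
# across slots.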
def test_context_size_exceeded():
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ] * 100,  # make the prompt too long
    })
    assert res.status_code == 400
    assert "error" in res.body
    assert res.body["error"]["type"] == "exceed_context_size_error"
    assert res.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


def test_context_size_exceeded_stream():
    global server
    server.start()
    try:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True}):
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        assert e.body["error"]["type"] == "exceed_context_size_error"
        assert e.body["error"]["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
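

# With "return_progress", prompt processing emits "prompt_progress" events
# ("total", "cache", "processed"), seemingly one per decoded batch, so the
# expected count depends on n_batch and on how much of the prompt is already
# in the cache from a previous request.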
@pytest.mark.parametrize(
    "n_batch,batch_count,reuse_cache",
    [
        (64, 3, False),
        (64, 1, True),
    ]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
    global server
    server.n_batch = n_batch
    server.n_ctx = 256
    server.n_slots = 1
    server.start()

    def make_cmpl_request():
        return server.make_stream_request("POST", "/chat/completions", data={
            "max_tokens": 10,
            "messages": [
                {"role": "user", "content": "This is a test" * 10},
            ],
            "stream": True,
            "return_progress": True,
        })

    if reuse_cache:
        # make a first request to populate the cache
        res0 = make_cmpl_request()
        for _ in res0:
            pass  # discard the output
    res = make_cmpl_request()
    last_progress = None
    total_batch_count = 0
    for data in res:
        cur_progress = data.get("prompt_progress", None)
        if cur_progress is None:
            continue
        if last_progress is not None:
            assert cur_progress["total"] == last_progress["total"]
            assert cur_progress["cache"] == last_progress["cache"]
            assert cur_progress["processed"] > last_progress["processed"]
        total_batch_count += 1
        last_progress = cur_progress
    assert last_progress is not None
    assert last_progress["total"] > 0
    assert last_progress["processed"] == last_progress["total"]
    assert total_batch_count == batch_count