import pytest
import requests
import time
import random
from openai import OpenAI
from utils import *

server = ServerPreset.tinyllama2()

JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"


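# autouse fixture: re-create the tinyllama2 server preset before every test so that
# per-test configuration changes do not leak between tests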
@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()


@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "return_tokens": return_tokens,
    })
    assert res.status_code == 200
    assert res.body["timings"]["prompt_n"] == n_prompt
    assert res.body["timings"]["predicted_n"] == n_predicted
    assert res.body["truncated"] == truncated
    assert type(res.body["has_new_line"]) == bool
    assert match_regex(re_content, res.body["content"])
    if return_tokens:
        assert len(res.body["tokens"]) > 0
        assert all(type(tok) == int for tok in res.body["tokens"])
    else:
        assert res.body["tokens"] == []


@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "stream": True,
    })
    content = ""
    for data in res:
        assert "stop" in data and type(data["stop"]) == bool
        if data["stop"]:
            assert data["timings"]["prompt_n"] == n_prompt
            assert data["timings"]["predicted_n"] == n_predicted
            assert data["truncated"] == truncated
            assert data["stop_type"] == "limit"
            assert type(data["has_new_line"]) == bool
            assert "generation_settings" in data
            assert server.n_predict is not None
            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
            assert data["generation_settings"]["seed"] == server.seed
            assert match_regex(re_content, content)
        else:
            assert len(data["tokens"]) > 0
            assert all(type(tok) == int for tok in data["tokens"])
            content += data["content"]


def test_completion_stream_vs_non_stream():
    global server
    server.start()
    res_stream = server.make_stream_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
        "stream": True,
    })
    res_non_stream = server.make_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
    })
    content_stream = ""
    for data in res_stream:
        content_stream += data["content"]
    assert content_stream == res_non_stream.body["content"]


def test_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].text is not None
    assert match_regex("(going|bed)+", res.choices[0].text)


def test_completion_stream_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
        stream=True,
    )
    output_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("(going|bed)+", output_text)


# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
@pytest.mark.slow
def test_completion_stream_with_openai_library_stops():
    global server
    server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M"
    server.model_hf_file = None
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
        stop=["User:\n", "Assistant:\n"],
        max_tokens=200,
        stream=True,
    )
    output_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'


@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": 0.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res


@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for seed in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": seed,
            "temperature": 1.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] != last_res.body["content"]
        last_res = res


# TODO: figure out why this does not work with temperature = 1
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
    global server
    server.n_batch = n_batch
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": temperature,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res


@pytest.mark.skip(reason="This test fails on Linux and needs to be fixed")
def test_cache_vs_nocache_prompt():
    global server
    server.start()
    res_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": True,
    })
    res_no_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res_cache.body["content"] == res_no_cache.body["content"]


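# a long input prompt with prompt caching disabled is expected to be rejected (HTTP 400)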
def test_nocache_long_input_prompt():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is" * 32,
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res.status_code == 400


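# prompts can also be sent as JSON objects using JSON_PROMPT_STRING_KEY (and optionally
# JSON_MULTIMODAL_KEY); without multimodal data the request should behave like a plain string prompt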
def test_json_prompt_no_mtmd():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res.status_code == 200


def test_json_prompt_mtm_error_when_not_supported():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    # MTMD is disabled on this model, so this should fail.
    assert res.status_code != 200


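# the "prompt" field also accepts pre-tokenized input: a single list of token ids, a batch
# of prompts, or a mix of strings, token lists and JSON prompt objects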
def test_completion_with_tokens_input():
    global server
    server.temperature = 0.0
    server.start()

    prompt_str = "I believe the meaning of life is"
    res = server.make_request("POST", "/tokenize", data={
        "content": prompt_str,
        "add_special": True,
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]

    # single completion
    res = server.make_request("POST", "/completion", data={
        "prompt": tokens,
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str

    # batch completion
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, tokens],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed JSON and tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": [
            tokens,
            {
                JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
            },
        ],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens in one sequence
    res = server.make_request("POST", "/completion", data={
        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str


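# run several completion requests concurrently and poll /slots while they are in flight;
# when there are at least as many requests as slots, all slots should be busy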
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2),  # some slots must be idle
    (4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another a poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]

    def check_slots_status():
        should_all_slots_busy = n_requests >= n_slots
        time.sleep(0.1)
        res = server.make_request("GET", "/slots")
        n_busy = sum([1 for slot in res.body if slot["is_processing"]])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # check results
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        res = results[i]
        assert res.status_code == 200
        assert type(res.body["content"]) == str
        assert len(res.body["content"]) > 10
        # FIXME: the result is not deterministic when using other slot than slot 0
        # assert match_regex(re_content, res.body["content"])


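# with a unified KV cache (kv_unified = True) the n_ctx cells are shared across all slots;
# depending on how the parallel requests fit, each one is expected to either succeed (200)
# or fail (500) as listed in expected_success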
@pytest.mark.parametrize(
    "n_ctx,n_slots,n_predict_vals,expected_success",
    [
        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
        (256, 4, [90, 90, 40, 75], [True, True, True, True]),
    ],
)
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
    global server
    server.n_slots = n_slots
    server.kv_unified = True
    server.n_ctx = n_ctx
    server.start()
    prompt = "A"
    tasks = []
    for n_predict in n_predict_vals:
        tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
    results = parallel_function_calls(tasks)
    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
        if expect_ok:
            assert res.status_code == 200
            assert "content" in res.body
            if "timings" in res.body:
                assert res.body["timings"]["predicted_n"] == n_predict
        else:
            assert res.status_code == 500
            assert "content" not in res.body


@pytest.mark.parametrize(
    "prompt,n_predict,response_fields",
    [
        ("I believe the meaning of life is", 8, []),
        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
    ],
)
def test_completion_response_fields(
    prompt: str, n_predict: int, response_fields: list[str]
):
    global server
    server.start()
    res = server.make_request(
        "POST",
        "/completion",
        data={
            "n_predict": n_predict,
            "prompt": prompt,
            "response_fields": response_fields,
        },
    )
    assert res.status_code == 200
    assert "content" in res.body
    assert len(res.body["content"])
    if len(response_fields):
        assert res.body["generation_settings/n_predict"] == n_predict
        assert res.body["prompt"] == "<s> " + prompt
        assert isinstance(res.body["content"], str)
        assert len(res.body) == len(response_fields)
    else:
        assert len(res.body)
        assert "generation_settings" in res.body


def test_n_probs():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "logprob" in tok and tok["logprob"] <= 0.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_logprobs"]) == 10
        for prob in tok["top_logprobs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "logprob" in prob and prob["logprob"] <= 0.0
            assert "bytes" in prob and type(prob["bytes"]) == list


def test_n_probs_stream():
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "stream": True,
    })
    for data in res:
        if data["stop"] == False:
            assert "completion_probabilities" in data
            assert len(data["completion_probabilities"]) == 1
            for tok in data["completion_probabilities"]:
                assert "id" in tok and tok["id"] > 0
                assert "token" in tok and type(tok["token"]) == str
                assert "logprob" in tok and tok["logprob"] <= 0.0
                assert "bytes" in tok and type(tok["bytes"]) == list
                assert len(tok["top_logprobs"]) == 10
                for prob in tok["top_logprobs"]:
                    assert "id" in prob and prob["id"] > 0
                    assert "token" in prob and type(prob["token"]) == str
                    assert "logprob" in prob and prob["logprob"] <= 0.0
                    assert "bytes" in prob and type(prob["bytes"]) == list


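# with post_sampling_probs, the server reports probabilities after sampling ("prob" in [0, 1]
# and "top_probs") instead of log-probabilities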
def test_n_probs_post_sampling():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "post_sampling_probs": True,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_probs"]) == 10
        for prob in tok["top_probs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
            assert "bytes" in prob and type(prob["bytes"]) == list
        # because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])


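# logit_bias can be given either as [token_id_or_string, bias] pairs or as an OpenAI-style
# {token: bias} mapping; biasing the excluded words to -100 should keep them out of the output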
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
    global server
    server.start()

    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]

    logit_bias = []
    if tokenize:
        res = server.make_request("POST", "/tokenize", data={
            "content": " " + " ".join(exclude) + " ",
        })
        assert res.status_code == 200
        tokens = res.body["tokens"]
        logit_bias = [[tok, -100] for tok in tokens]
    else:
        logit_bias = [[" " + tok + " ", -100] for tok in exclude]
    if openai_style:
        logit_bias = {el[0]: -100 for el in logit_bias}

    res = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": "What is the best book",
        "logit_bias": logit_bias,
        "temperature": 0.0,
    })
    assert res.status_code == 200
    output_text = res.body["content"]
    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)


def test_cancel_request():
    global server
    server.n_ctx = 4096
    server.n_predict = -1
    server.n_slots = 1
    server.server_slots = True
    server.start()
    # send a request that will take a long time, but cancel it before it finishes
    try:
        server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
        }, timeout=0.1)
    except requests.exceptions.ReadTimeout:
        pass  # expected
    # make sure the slot is free
    time.sleep(1)  # wait for HTTP_POLLING_SECONDS
    res = server.make_request("GET", "/slots")
    assert res.body[0]["is_processing"] == False


# this test exercises the host-memory prompt cache
# ref: https://github.com/ggml-org/llama.cpp/pull/16391
# ref: https://github.com/ggml-org/llama.cpp/pull/17078
def test_completion_prompt_cache():
    global server
    server.n_slots = 2
    server.kv_unified = True
    server.start()

    for _ in range(16):
        # generate alternating random prompts with variable lengths in order to get them in and out of the cache
        r = random.randint(0, 4)
        prompt = (" Hello " + str(r)) * (40 + r)
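        # the expected prompt size assumes 5 tokens per " Hello <r>" repetition plus 2 extra tokens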
        n_prompt = (40 + r) * 5 + 2
        n_predict = random.randint(1, 8)

        res = server.make_request(
            "POST",
            "/completion",
            data={
                "prompt": prompt,
                "n_predict": n_predict,
            },
        )
        assert res.status_code == 200
        assert "content" in res.body

        content = res.body["content"]
        assert isinstance(content, str)
        assert len(content) > 0
        assert type(res.body["has_new_line"]) == bool

        assert "timings" in res.body
        timings = res.body["timings"]
        assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
        assert "predicted_n" in timings and timings["predicted_n"] == n_predict

        assert "tokens" in res.body and isinstance(res.body["tokens"], list)