server-bench.py

#!/usr/bin/env python3

import argparse
import json
import os
import random
import sqlite3
import subprocess
from time import sleep, time
from typing import Optional, Union

import datasets
import logging
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")
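

# Text prompts come from the MMLU test questions; any other prompt source name
# makes this function return None, in which case the caller falls back to
# synthetic prompts built from random token IDs (see get_prompts_rng below).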
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
    ret = []
    if dataset_name.lower() == "mmlu":
        logger.info("Loading MMLU dataset...")
        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    else:
        return None
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
    assert n_prompts >= 0
    ret: list[int] = []
    for i in range(n_prompts):
        if seed_offset >= 0:
            random.seed(3 * (seed_offset + 1000 * i) + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret


def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
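

# If path_server is an http(s) URL, an already running server is reused as-is;
# otherwise the llama-server binary is launched locally (configured via the
# LLAMA_ARG_* environment variables) and polled on /health until it is ready.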
def get_server(path_server: str, path_log: Optional[str]) -> dict:
    if path_server.startswith("http://") or path_server.startswith("https://"):
        return {"process": None, "address": path_server, "fout": None}

    if os.environ.get("LLAMA_ARG_HOST") is None:
        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
    if os.environ.get("LLAMA_ARG_PORT") is None:
        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
        os.environ["LLAMA_ARG_PORT"] = "8080"
    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
    assert hostname is not None
    assert port is not None
    address: str = f"http://{hostname}:{port}"
    logger.info(f"Starting the llama.cpp server under {address}...")

    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    response.raise_for_status()
    prompt: str = json.loads(response.text)["prompt"]
    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    response.raise_for_status()
    tokens: list[str] = json.loads(response.text)["tokens"]
    return len(tokens)
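

# Sends one request and records the arrival time of every streamed token:
# an external server is queried via the OpenAI-compatible /v1/completions
# endpoint, synthetic token-ID prompts go straight to /completion with prompt
# caching disabled, and text prompts are first run through /apply-template.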
def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    t_submit = time()
    if data["external_server"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True,
            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
    elif data["synthetic_prompt"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    else:
        response = session.post(
            f"{server_address}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
        )
        response.raise_for_status()
        prompt: str = json.loads(response.text)["prompt"]

        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    response.raise_for_status()

    lines = []
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=False):
        if not line.startswith(b"data: "):
            continue
        lines.append(line)
        token_arrival_times.append(time())
    token_arrival_times = token_arrival_times[:-1]
    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
        token_arrival_times = token_arrival_times[:-1]

    return (t_submit, token_arrival_times)
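

# Full benchmark run: set default LLAMA_ARG_* values where not given, build the
# per-request payloads, send all requests in parallel (one worker per server
# slot), then report aggregate statistics, optionally store them in SQLite,
# and save the prompt_time.png / gen_rate.png plots.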
def benchmark(
        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
        n_predict: int, n_predict_min: int, seed_offset: int):
    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")

    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"

    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore

    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []
    if synthetic_prompts:
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
        prompts = get_prompts_rng(prompt_n)
    else:
        n_predict_min = n_predict

    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]
        assert external_server == (server["process"] is None)

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            if seed_offset >= 0:
                random.seed(3 * (seed_offset + 1000 * i) + 1)
            data.append({
                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None and server["process"] is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

    prompt_t = []
    token_t = []
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)
    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens: {token_t.shape[0]}")
    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

    if path_db is not None:
        con = sqlite3.connect(path_db)
        cursor = con.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS server_bench"
            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
        cursor.execute(
            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
        con.commit()

    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.title(name or "")
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(name or "")
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the file to write the server log to")
    parser.add_argument("--path_db", type=str, default=None, help="Path to an SQLite database to store the benchmark results in")
    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
    args = parser.parse_args()
    benchmark(**vars(args))
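
# Example invocation (a sketch, not part of the original script): assumes
# llama-server is on PATH and reads the model path from the LLAMA_ARG_MODEL
# environment variable (see llama-server --help); the database filename is
# arbitrary:
#
#   LLAMA_ARG_MODEL=/path/to/model.gguf python3 server-bench.py \
#       --prompt_source rng-1024-2048 --n_prompts 100 --path_db results.sqlite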