#!/usr/bin/env python3
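"""Benchmark the throughput of the llama.cpp HTTP server.

Results are printed to the console and visualized as plots that are saved to the
current working directory. Server settings such as the model path are passed via
the corresponding environment variables (see `llama-server --help`).

Example invocation (the model path is a placeholder; LLAMA_ARG_MODEL is the
environment-variable form of `llama-server --model`):

    LLAMA_ARG_MODEL=models/model.gguf ./server-bench.py --prompt_source rng-1024-2048 --n_prompts 100
"""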

import argparse
import json
import logging
import os
import random
import subprocess
from time import sleep, time
from typing import Optional, Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
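    """Load prompt texts for a named dataset, or return None if the name is not recognized.

    Currently only "mmlu" (questions from the cais/mmlu test split) is supported; any
    other name signals to the caller that synthetic prompts should be used instead.
    A non-negative n_prompts truncates the list to the first n_prompts entries.
    """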
    ret = []
    if dataset_name.lower() == "mmlu":
        logger.info("Loading MMLU dataset...")
        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    else:
        return None
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
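    """Draw n_prompts pseudorandom prompt lengths from [prompt_length_min, prompt_length_max].

    A non-negative seed_offset makes the lengths reproducible by reseeding the RNG per prompt.
    """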
    assert n_prompts >= 0
    ret: list[int] = []
    for i in range(n_prompts):
        if seed_offset >= 0:
            random.seed(3 * (seed_offset + 1000 * i) + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret


def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
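    """Generate synthetic prompts as lists of random token IDs with the given lengths."""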
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]


def get_server(path_server: str, path_log: Optional[str]) -> dict:
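    """Start the llama.cpp server and wait until its /health endpoint reports ready.

    Host and port are taken from LLAMA_ARG_HOST/LLAMA_ARG_PORT (defaulting to
    127.0.0.1:8080) and the server output is written to path_log, if given.
    Returns a dict with the server process, its address, and the log file handle.
    """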
    if os.environ.get("LLAMA_ARG_HOST") is None:
        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
    if os.environ.get("LLAMA_ARG_PORT") is None:
        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
        os.environ["LLAMA_ARG_PORT"] = "8080"
    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
    assert hostname is not None
    assert port is not None
    address: str = f"http://{hostname}:{port}"
    logger.info(f"Starting the llama.cpp server under {address}...")

    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


def get_prompt_length(data: dict) -> int:
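    """Return the length in tokens of a prompt after applying the server's chat template."""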
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[int] = json.loads(response.text)["tokens"]
    return len(tokens)


def send_prompt(data: dict) -> tuple[float, list[float]]:
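    """Send a single prompt to the server and record per-token arrival times.

    Synthetic prompts (token lists) go straight to /completion; text prompts are
    first converted via /apply-template. Returns the submission time and the
    arrival times of the streamed "data: " lines, excluding the final one.
    """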
    session = data["session"]
    server_address: str = data["server_address"]

    t_submit = time()
    if data["synthetic_prompt"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    else:
        response = session.post(
            f"{server_address}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
        )
        if response.status_code != 200:
            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
        prompt: str = json.loads(response.text)["prompt"]

        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)

    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=False):
        if not line.startswith(b"data: "):
            continue
        token_arrival_times.append(time())
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")

    return (t_submit, token_arrival_times)


def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
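    """Run the full benchmark: start the server, send all prompts in parallel, then report and plot the results.

    Unset LLAMA_ARG_* variables are filled with benchmark-friendly defaults (32 parallel
    slots, all layers offloaded, flash attention enabled, and a context size derived from
    the prompt and prediction lengths).
    """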
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"

    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore

    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []

    if synthetic_prompts:
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
        prompts = get_prompts_rng(prompt_n)
    else:
        n_predict_min = n_predict

    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            if seed_offset >= 0:
                random.seed(3 * (seed_offset + 1000 * i) + 1)
            data.append({
                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
                "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

    prompt_t = []
    token_t = []
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens: {token_t.shape[0]}")
    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
    logger.info("")
    logger.info(
        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")

    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the file for logging the server output")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    parser.add_argument(
        "--seed_offset", type=int, default=0,
        help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
    args = parser.parse_args()
    benchmark(**vars(args))