#!/usr/bin/env python3
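"""Benchmark the throughput of the llama.cpp HTTP server.

Starts a llama-server process (configured via LLAMA_ARG_* environment variables),
sends a batch of prompts to it concurrently, and reports prompt processing and token
generation speeds. Results are printed to the console and saved as plots
(prompt_time.png, gen_rate.png) in the current working directory.
"""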

import argparse
import json
import logging
import os
import random
import subprocess
from time import sleep, time
from typing import Optional, Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
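    """Return up to n_prompts text prompts from the named dataset (currently only 'mmlu').

    Returns None for unknown dataset names; a negative n_prompts keeps all prompts."""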
    ret = []
    if dataset_name.lower() == "mmlu":
        logger.info("Loading MMLU dataset...")
        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    else:
        return None
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
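    """Draw a reproducible random prompt length in [prompt_length_min, prompt_length_max] for each of n_prompts prompts."""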
    assert n_prompts >= 0
    ret: list[int] = []
    for i in range(n_prompts):
        random.seed(13 * i + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret


def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
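    """Build synthetic prompts as lists of random token IDs with the given lengths."""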
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]


def get_server(path_server: str, path_log: Optional[str]) -> dict:
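    """Launch the llama.cpp server binary and block until it is healthy; return its process, address, and log handle."""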
    logger.info("Starting the llama.cpp server...")
    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
    address: str = f"http://{hostname}:{port}"

    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
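    # Poll the /health endpoint once per second until the server reports ready,
    # aborting if the process exits or the connection keeps failing.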
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


def get_prompt_length(data: dict) -> int:
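    """Return the token count of a prompt by applying the server's chat template and tokenizing the result."""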
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[str] = json.loads(response.text)["tokens"]
    return len(tokens)


def send_prompt(data: dict) -> tuple[float, list[float]]:
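    """Send one prompt to the server and return the submission time together with the arrival time of each streamed token."""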
    session = data["session"]
    server_address: str = data["server_address"]

    t_submit = time()
    if data["synthetic_prompt"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    else:
        response = session.post(
            f"{server_address}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
        )
        if response.status_code != 200:
            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
        prompt: str = json.loads(response.text)["prompt"]

        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
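    # Timestamp every streamed "data: " event; the final event closes the stream and is
    # excluded from the token arrival times.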
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=False):
        if not line.startswith(b"data: "):
            continue
        token_arrival_times.append(time())
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")

    return (t_submit, token_arrival_times)


def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
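    """Start the server with benchmark-friendly defaults, send all prompts in parallel, and report statistics and plots."""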
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"

    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []

    if synthetic_prompts:
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
        prompts = get_prompts_rng(prompt_n)
    else:
        n_predict_min = n_predict
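    # If no context size is given, reserve room per slot for the prompt (its maximum length
    # for synthetic prompts, assumed <= 2048 tokens for text prompts) plus n_predict,
    # with 5% headroom, multiplied by the number of parallel slots.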
    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]
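        # Use an HTTP connection pool at least as large as the number of parallel requests.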
        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            random.seed(13 * i + 1)
            data.append({
                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
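        # Dispatch all requests concurrently, one worker thread per server slot.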
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()
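    # Aggregate per-request results: time to first token, global token arrival times, and the
    # summed context depth (prompt length plus position) at which each token was generated.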
    prompt_t = []
    token_t = []
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0

    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)
    logger.info("")
    logger.info(f"Benchmark duration:                {token_t_last:.2f} s")
    logger.info(f"Request throughput:                {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency:            {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens:            {token_t.shape[0]}")
    logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
    logger.info("")
    logger.info(
        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast relative to the network or the Python script itself (e.g. when serving a very small model).")
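    # Plot time to first token vs. prompt length, and a histogram of tokens generated per second.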
    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the file to which the server output is logged")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    args = parser.parse_args()

    benchmark(**vars(args))