server-bench.py

#!/usr/bin/env python3
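"""Benchmark the throughput of the llama.cpp HTTP server.

Prompts are taken from the MMLU test split and sent to a locally started
llama-server instance; results are printed to the console and saved as plots
in the current working directory.

Example invocation (the model path is illustrative):
    python3 server-bench.py --path_model models/model.gguf
"""
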
import argparse
import json
import logging
import subprocess
from time import sleep, time
from typing import Optional

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


def get_prompts(n_prompts: int) -> list[str]:
    logger.info("Loading MMLU dataset...")
    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    # A negative n_prompts means: use every question in the test split.
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


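# Start a llama.cpp server subprocess with the given settings and wait until its /health endpoint reports ready.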
def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
    logger.info("Starting the llama.cpp server...")
    address = f"http://localhost:{port}"

    popen_args: list[str] = [
        path_server,
        "--flash-attn",
        "--n-gpu-layers", str(n_gpu_layers),
        "--parallel", str(parallel),
        "--ctx-size", str(parallel * ctx_size),
        "--model", path_model,
        "--port", str(port),
        "--swa-full",  # FIXME performance bad otherwise
        # "--attn-streams",
    ]
    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)

    # Poll the /health endpoint until the server is ready or has exited.
    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


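# Apply the server's chat template to the prompt, then tokenize the result to get the prompt length in tokens.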
def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[int] = json.loads(response.text)["tokens"]
    return len(tokens)


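# Send one prompt as a streamed completion request; return the prompt processing time [ms] and the arrival times of the generated tokens.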
def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
    response = session.post(f"{server_address}/completion", json=json_data, stream=True)

    # The completion is streamed as server-sent events; record the arrival time of each "data: " line.
    last_valid_line: str = ""
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=True):
        if not line.startswith("data: "):
            continue
        last_valid_line = line
        token_arrival_times.append(time())
    # The final "data: " line carries the timings summary rather than a new token, so drop its timestamp.
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")

    timings: dict = json.loads(last_valid_line[6:])["timings"]
    return (timings["prompt_ms"], token_arrival_times)


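# Run the full benchmark: start the server, send all prompts in parallel, then report statistics and write plots.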
def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
    num_workers: int = parallel + 1  # one worker thread more than the number of server slots
    prompts: list[str] = get_prompts(n_prompts)

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
        server_address: str = server["address"]

        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

        logger.info("Getting the prompt lengths...")
        prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

    prompt_ms = []
    token_t = []
    depth_sum: int = 0
    for pn, (pms, tat) in zip(prompt_n, results):
        prompt_ms.append(pms)
        token_t += tat
        n_tokens: int = len(tat)
        # Each generated token sits at context depth prompt_length + number of previously generated tokens.
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_ms = np.array(prompt_ms, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
    logger.info(f"Total generated tokens: {token_t.shape[0]}")
    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

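    # Plot time to first token vs. prompt length (prompt_time.png) and the token generation rate over time (gen_rate.png).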
    plt.figure()
    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05 * np.max(prompt_n))
    plt.ylim(0, 1.05 * np.max(prompt_ms))
    plt.title(path_model)
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(path_model)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
    parser.add_argument("--path_log", type=str, default=None, help="Path to the file for writing the server log (default: no log)")
    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    args = parser.parse_args()
    benchmark(**vars(args))