#!/usr/bin/env python3

import argparse
import json
import logging
import os
import random
import subprocess
from time import sleep, time
from typing import Optional, Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")
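

# Return text prompts from the named dataset, or None if the name is not a known dataset
# (the caller then falls back to synthetic prompts).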
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
    ret = []
    if dataset_name.lower() == "mmlu":
        logger.info("Loading MMLU dataset...")
        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    else:
        return None
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret
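

# Deterministically draw one random prompt length per prompt from [prompt_length_min, prompt_length_max].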
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
    assert n_prompts >= 0
    ret: list[int] = []
    for i in range(n_prompts):
        random.seed(13 * i + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret
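

# Build synthetic prompts as lists of random token IDs with the given lengths.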
def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
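

# Launch the llama.cpp server as a subprocess and wait until its /health endpoint reports ready.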
def get_server(path_server: str, path_log: Optional[str]) -> dict:
    logger.info("Starting the llama.cpp server...")
    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
    address: str = f"http://{hostname}:{port}"

    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}
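

# Determine the length of a prompt in tokens by applying the chat template and tokenizing the result via the server.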
def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[int] = json.loads(response.text)["tokens"]

    return len(tokens)
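

# Send a single prompt to the server with streaming enabled and return the submission time
# together with the arrival times of the generated tokens.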
def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    t_submit = time()
    if data["synthetic_prompt"]:
        json_data: dict = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
    else:
        response = session.post(
            f"{server_address}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
        )
        if response.status_code != 200:
            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
        prompt: str = json.loads(response.text)["prompt"]

        json_data = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
        response = session.post(f"{server_address}/completion", json=json_data, stream=True)

    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=False):
        if not line.startswith(b"data: "):
            continue
        token_arrival_times.append(time())
    # Drop the final entry, which corresponds to the end-of-stream chunk rather than a generated token.
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")

    return (t_submit, token_arrival_times)
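

# Run the benchmark: configure the server via LLAMA_ARG_* environment variables, load or generate the prompts,
# send them to the server in parallel, and report throughput/latency statistics and plots.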
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"

    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []

    if synthetic_prompts:
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
        prompts = get_prompts_rng(prompt_n)
    else:
        n_predict_min = n_predict

    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            random.seed(13 * i + 1)
            data.append({
                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()
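
    # Per request: time to first token, arrival time of every generated token, and the summed context depth
    # at which tokens were generated (prompt length + position), used for the average generation depth.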
    prompt_t = []
    token_t = []
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)
    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
    logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens: {token_t.shape[0]}")
    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
    logger.info("")
    logger.info(
        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast relative to the network or the Python script (e.g. when serving a very small model).")
    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the file for writing the server log")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "'rng-MIN-MAX' for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    args = parser.parse_args()

    benchmark(**vars(args))