#!/usr/bin/env uv run
'''
    Simplistic tool call benchmarks for llama-server and ollama.

    Essentially runs the tests at tools/server/tests/unit/test_tool_call.py N times, at different temperatures
    and against different backends (current llama-server, baseline llama-server and ollama), then plots the
    results of multiple runs (from the same .jsonl file or multiple ones) as a success-rate heatmap.

    Simple usage example:

        cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server
        export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
        export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}

        ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b

        ./scripts/tool_bench.py plot *.jsonl                       # Opens window w/ heatmap
        ./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png # Saves heatmap to qwen.png

    (please see ./scripts/tool_bench.sh for a more complete example)
'''
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pytest",
#     "pandas",
#     "matplotlib",
#     "seaborn",
#     "requests",
#     "wget",
#     "typer",
# ]
# ///
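# (The `# /// script` block above is PEP 723 inline script metadata; running the
# file via `uv run` reads it to provision Python >= 3.10 and the listed dependencies.)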
from contextlib import contextmanager
from pathlib import Path
import re
from statistics import mean, median
from typing import Annotated, Dict, List, Optional, Tuple
import atexit
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import subprocess
import sys
import time
import typer

sys.path.insert(0, Path(__file__).parent.parent.as_posix())
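# `if True:` keeps the next imports from being hoisted above the sys.path tweak
# by formatters/linters (the repo root must be on the import path first).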
if True:
    from tools.server.tests.utils import ServerProcess
    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather
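

# scoped_server ensures a spawned server is stopped both when the enclosing
# `with` block exits and at interpreter exit (via atexit), so aborted runs
# don't leak llama-server processes.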
@contextmanager
def scoped_server(sp: ServerProcess):
    def stop():
        nonlocal sp
        if sp is not None:
            sp.stop()
            sp = None  # type: ignore
    atexit.register(stop)
    yield sp
    stop()


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = typer.Typer()


@app.command()
def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):
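    # Each input line is a JSON record as written by the `run` command, e.g.
    # (illustrative values only):
    #   {"model": "...", "temp": 0.0, "server_name": "llama-server", "test": "weather",
    #    "success_ratio": 0.9, "success_count": 9, "failure_count": 1, ...}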
    lines: List[Dict] = []
    for file in files:
        if not file.exists():
            logger.error(f"File not found: {file}")
            continue

        try:
            with file.open() as f:
                raw_data = f.read()
            logger.info(f"Reading {file} ({len(raw_data)} bytes)")

            for line_num, line in enumerate(raw_data.split('\n'), 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    lines.append(record)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
        except Exception as e:
            logger.error(f"Error processing {file}: {e}")

    if not lines:
        raise Exception("No valid data was loaded")
    data_dict: Dict[Tuple, float] = {}
    models: List[str] = []
    temps = set()
    tests = set()
    server_names = set()
    total_counts = set()

    for rec in lines:
        try:
            model = rec["model"]
            temp = rec["temp"]
            server_name = rec["server_name"]
            test = rec["test"]
            success = rec["success_ratio"]
            success_count = rec["success_count"]
            failure_count = rec["failure_count"]
            total_count = success_count + failure_count
            total_counts.add(total_count)

            if test_regex and not re.search(test_regex, test):
                continue

            if server_regex and not re.search(server_regex, server_name):
                continue

            data_dict[(model, temp, server_name, test)] = success

            if model not in models:
                models.append(model)

            temps.add(temp)
            tests.add(test)
            server_names.add(server_name)

        except KeyError as e:
            logger.warning(f"Missing required field in record: {e}")

    if len(total_counts) > 1:
        logger.warning(f"Total counts are not consistent: {total_counts}")
    # Sort the collected values; a None temperature (server default) sorts first
    temps = list(sorted(temps, key=lambda x: x if x is not None else -1))
    tests = list(sorted(tests))
    server_names = list(sorted(server_names))

    logger.info(f"Processed {len(lines)} lines")
    logger.info(f"Found {len(data_dict)} valid data points")
    logger.info(f"Models: {models}")
    logger.info(f"Temperatures: {temps}")
    logger.info(f"Tests: {tests}")
    logger.info(f"Servers: {server_names}")
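
    # Lay the heatmap out as one row per (model, temperature) and one column per
    # (server, test); combinations with no data become NaN (blank cells).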
    matrix: list[list[float]] = []
    index: list[str] = []

    all_cols = [
        (server_name, test)
        for server_name in server_names
        for test in tests
    ]
    for model in models:
        for temp in temps:
            index.append(f"{model} @ {temp}")
            row_vals = [
                data_dict.get((model, temp, server_name, test), np.nan)
                for server_name, test in all_cols
            ]
            matrix.append(row_vals)

    columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]

    df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))
    plt.figure(figsize=(12, 6))
    sns.heatmap(
        df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
        cbar_kws={"label": "Success Ratio"},
    )

    plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20)
    plt.xlabel("Server & Test", labelpad=10)
    plt.ylabel("Model @ Temperature", labelpad=10)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()

    if output:
        plt.savefig(output, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {output}")
    else:
        plt.show()


@app.command()
def run(
    output: Annotated[Path, typer.Option(help="Output JSONL file")],
    model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
    hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
    chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
    chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
    ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
    llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server binary path to use as baseline")] = None,
    n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
    temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test (negative values mean the server default)")] = None,
    top_p: Annotated[Optional[float], typer.Option(help="top_p sampling parameter")] = None,
    top_k: Annotated[Optional[int], typer.Option(help="top_k sampling parameter")] = None,
    ctk: Annotated[Optional[str], typer.Option(help="KV cache K type (--cache-type-k)")] = None,
    ctv: Annotated[Optional[str], typer.Option(help="KV cache V type (--cache-type-v)")] = None,
    fa: Annotated[Optional[bool], typer.Option(help="Enable flash attention")] = None,
    seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
    port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
    force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
    append: Annotated[bool, typer.Option(help="Append to output file")] = False,
    test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
    test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
    test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
):
    n_predict = 512  # High because of DeepSeek R1
    # n_ctx = 8192
    n_ctx = 2048

    # Derive a display name from --hf / --ollama when --model isn't given
    if model is None:
        if hf is not None:
            model = hf.split("/")[-1]
        elif ollama is not None:
            model = ollama

    assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite or --append to extend it"
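
    # Results are streamed to the output file as JSON Lines, flushed after each record.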
    with output.open('a' if append else 'w') as output_file:

        def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
            request_kwargs = {**request_kwargs}
            if temp is not None:
                request_kwargs['temperature'] = temp
            if top_p is not None:
                request_kwargs['top_p'] = top_p
            if top_k is not None:
                request_kwargs['top_k'] = top_k
            if seed is not None:
                request_kwargs['seed'] = seed
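
            # Prompt caching is disabled so that each of the n attempts below is an
            # independent sample rather than a replay of a cached prefix.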
            request_kwargs['cache_prompt'] = False

            tests = {}
            if test_hello_world:
                tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
            if test_weather:
                tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
            if test_calc_result:
                tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)
            for test_name, test in tests.items():
                success_count = 0
                failure_count = 0
                failures = []
                success_times = []
                failure_times = []
                logger.info(f"Running {test_name} ({server_name}, {model}): ")
                for i in range(n):
                    start_time = time.time()

                    def elapsed():
                        return time.time() - start_time

                    try:
                        test(server)
                        success_times.append(elapsed())
                        success_count += 1
                        logger.info('success')
                    except Exception as e:
                        logger.error(f'failure: {e}')
                        failure_count += 1
                        failure_times.append(elapsed())
                        failures.append(str(e))
                        # import traceback
                        # traceback.print_exc()
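
                # Summarize all n attempts for this (model, temp, server, test)
                # combination as a single JSONL record.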
                output_file.write(json.dumps({**output_kwargs, **dict(
                    model=model,
                    server_name=server_name,
                    model_id=model_id,
                    test=test_name,
                    temp=temp,
                    top_p=top_p,
                    top_k=top_k,
                    ctk=ctk,
                    ctv=ctv,
                    seed=seed,
                    success_ratio=float(success_count) / n,
                    avg_time=mean(success_times + failure_times),
                    median_time=median(success_times + failure_times),
                    success_count=success_count,
                    success_times=success_times,
                    failure_count=failure_count,
                    failure_times=failure_times,
                    failures=list(set(failures)),
                )}) + '\n')
                output_file.flush()
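
        # Negative --temp values are mapped to None below, i.e. "use the server's
        # default temperature".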
        for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
            if hf is not None:

                servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
                if llama_baseline is not None:
                    servers.append(('llama-server (baseline)', llama_baseline))

                for server_name, server_path in servers:
                    server = ServerProcess()
                    server.n_ctx = n_ctx
                    server.n_slots = 1
                    server.jinja = True
                    server.ctk = ctk
                    server.ctv = ctv
                    server.fa = "on" if fa else "off"
                    server.n_predict = n_predict
                    server.model_hf_repo = hf
                    server.model_hf_file = None
                    server.chat_template = chat_template
                    server.chat_template_file = chat_template_file
                    server.server_path = server_path
                    if port is not None:
                        server.server_port = port
                    # server.debug = True

                    with scoped_server(server):
                        server.start(timeout_seconds=15 * 60)
                        for ignore_chat_grammar in [False]:
                            run(
                                server,
                                server_name=server_name,
                                model_id=hf,
                                temp=t,
                                output_kwargs=dict(
                                    chat_template=chat_template,
                                    chat_template_file=chat_template_file,
                                ),
                                request_kwargs=dict(
                                    ignore_chat_grammar=ignore_chat_grammar,
                                ),
                            )

            if ollama is not None:
                server = ServerProcess()
                server.server_port = 11434
                server.server_host = "localhost"
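                # ollama is assumed to already be serving on its default port;
                # ServerProcess is used only as a client handle here and is never start()ed.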
                subprocess.check_call(["ollama", "pull", ollama])

                with scoped_server(server):
                    run(
                        server,
                        server_name="ollama",
                        model_id=ollama,
                        temp=t,
                        output_kwargs=dict(
                            chat_template=None,
                            chat_template_file=None,
                        ),
                        request_kwargs=dict(
                            model=ollama,
                            max_tokens=n_predict,
                            num_ctx=n_ctx,
                        ),
                    )


if __name__ == "__main__":
    app()