1
0

tool_bench.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. #!/usr/bin/env uv run
  2. '''
  3. Simplistic tool call benchmarks for llama-server and ollama.
  4. Essentially runs the tests at tools/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama),
  5. and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap.
  6. Simple usage example:
  7. cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server
  8. export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
  9. export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
  10. ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
  11. ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
  12. ./scripts/tool_bench.py plot *.jsonl # Opens window w/ heatmap
  13. ./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png # Saves heatmap to qwen.png
  14. (please see ./scripts/tool_bench.sh for a more complete example)
  15. '''
  16. # /// script
  17. # requires-python = ">=3.10"
  18. # dependencies = [
  19. # "pytest",
  20. # "pandas",
  21. # "matplotlib",
  22. # "seaborn",
  23. # "requests",
  24. # "wget",
  25. # "typer",
  26. # ]
  27. # ///
  28. from contextlib import contextmanager
  29. from pathlib import Path
  30. import re
  31. from statistics import mean, median
  32. from typing import Annotated, Dict, List, Optional, Tuple
  33. import atexit
  34. import json
  35. import logging
  36. import matplotlib.pyplot as plt
  37. import numpy as np
  38. import pandas as pd
  39. import seaborn as sns
  40. import subprocess
  41. import sys
  42. import time
  43. import typer
  44. sys.path.insert(0, Path(__file__).parent.parent.as_posix())
  45. if True:
  46. from tools.server.tests.utils import ServerProcess
  47. from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
  48. @contextmanager
  49. def scoped_server(sp: ServerProcess):
  50. def stop():
  51. nonlocal sp
  52. if sp is not None:
  53. sp.stop()
  54. sp = None # type: ignore
  55. atexit.register(stop)
  56. yield sp
  57. stop()
  58. logging.basicConfig(
  59. level=logging.INFO,
  60. format='%(asctime)s - %(levelname)s - %(message)s'
  61. )
  62. logger = logging.getLogger(__name__)
  63. app = typer.Typer()
  64. @app.command()
  65. def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):
  66. lines: List[Dict] = []
  67. for file in files:
  68. if not file.exists():
  69. logger.error(f"File not found: {file}")
  70. continue
  71. try:
  72. with file.open() as f:
  73. raw_data = f.read()
  74. logger.info(f"Reading {file} ({len(raw_data)} bytes)")
  75. for line_num, line in enumerate(raw_data.split('\n'), 1):
  76. line = line.strip()
  77. if not line:
  78. continue
  79. try:
  80. record = json.loads(line)
  81. lines.append(record)
  82. except json.JSONDecodeError as e:
  83. logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
  84. except Exception as e:
  85. logger.error(f"Error processing {file}: {e}")
  86. if not lines:
  87. raise Exception("No valid data was loaded")
  88. data_dict: Dict[Tuple, float] = {}
  89. models: List[str] = []
  90. temps = set()
  91. tests = set()
  92. server_names = set()
  93. total_counts = set()
  94. for rec in lines:
  95. try:
  96. model = rec["model"]
  97. temp = rec["temp"]
  98. server_name = rec["server_name"]
  99. test = rec["test"]
  100. success = rec["success_ratio"]
  101. success_count = rec["success_count"]
  102. failure_count = rec["failure_count"]
  103. total_count = success_count + failure_count
  104. total_counts.add(total_count)
  105. if test_regex and not re.search(test_regex, test):
  106. continue
  107. if server_regex and not re.search(server_regex, server_name):
  108. continue
  109. data_dict[(model, temp, server_name, test)] = success
  110. if model not in models:
  111. models.append(model)
  112. temps.add(temp)
  113. tests.add(test)
  114. server_names.add(server_name)
  115. except KeyError as e:
  116. logger.warning(f"Missing required field in record: {e}")
  117. if len(total_counts) > 1:
  118. logger.warning(f"Total counts are not consistent: {total_counts}")
  119. # Sort the collected values
  120. temps = list(sorted(temps, key=lambda x: x if x is not None else -1))
  121. tests = list(sorted(tests))
  122. server_names = list(sorted(server_names))
  123. logger.info(f"Processed {len(lines)} lines")
  124. logger.info(f"Found {len(data_dict)} valid data points")
  125. logger.info(f"Models: {models}")
  126. logger.info(f"Temperatures: {temps}")
  127. logger.info(f"Tests: {tests}")
  128. logger.info(f"Servers: {server_names}")
  129. matrix: list[list[float]] = []
  130. index: list[str] = []
  131. all_cols = [
  132. (server_name, test)
  133. for server_name in server_names
  134. for test in tests
  135. ]
  136. for model in models:
  137. for temp in temps:
  138. index.append(f"{model} @ {temp}")
  139. row_vals = [
  140. data_dict.get((model, temp, server_name, test), np.nan)
  141. for server_name, test in all_cols
  142. ]
  143. matrix.append(row_vals)
  144. columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]
  145. df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))
  146. plt.figure(figsize=(12, 6))
  147. sns.heatmap(
  148. df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
  149. cbar_kws={"label": "Success Ratio"},
  150. )
  151. plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20)
  152. plt.xlabel("Server & Test", labelpad=10)
  153. plt.ylabel("Model @ Temperature", labelpad=10)
  154. plt.xticks(rotation=45, ha='right')
  155. plt.yticks(rotation=0)
  156. plt.tight_layout()
  157. if output:
  158. plt.savefig(output, dpi=300, bbox_inches='tight')
  159. logger.info(f"Plot saved to {output}")
  160. else:
  161. plt.show()
  162. @app.command()
  163. def run(
  164. output: Annotated[Path, typer.Option(help="Output JSON file")],
  165. model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
  166. hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
  167. chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
  168. ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
  169. llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
  170. n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
  171. temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None,
  172. top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None,
  173. top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None,
  174. ctk: Annotated[Optional[str], typer.Option(help="ctk")] = None,
  175. ctv: Annotated[Optional[str], typer.Option(help="ctv")] = None,
  176. fa: Annotated[Optional[bool], typer.Option(help="fa")] = None,
  177. seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
  178. port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
  179. force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
  180. append: Annotated[bool, typer.Option(help="Append to output file")] = False,
  181. test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
  182. test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
  183. test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
  184. ):
  185. # Check only one of output and append
  186. n_predict = 512 # High because of DeepSeek R1
  187. # n_ctx = 8192
  188. n_ctx = 2048
  189. assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
  190. with output.open('a' if append else 'w') as output_file:
  191. def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
  192. request_kwargs = {**request_kwargs}
  193. if temp is not None:
  194. request_kwargs['temperature'] = temp
  195. if top_p is not None:
  196. request_kwargs['top_p'] = top_p
  197. if top_k is not None:
  198. request_kwargs['top_k'] = top_k
  199. if seed is not None:
  200. request_kwargs['seed'] = seed
  201. request_kwargs['cache_prompt'] = False
  202. tests = {}
  203. if test_hello_world:
  204. tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
  205. if test_weather:
  206. tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
  207. if test_calc_result:
  208. tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)
  209. for test_name, test in tests.items():
  210. success_count = 0
  211. failure_count = 0
  212. failures = []
  213. success_times = []
  214. failure_times = []
  215. logger.info(f"Running {test_name} ({server_name}, {model}): ")
  216. for i in range(n):
  217. start_time = time.time()
  218. def elapsed():
  219. return time.time() - start_time
  220. try:
  221. test(server)
  222. success_times.append(elapsed())
  223. success_count += 1
  224. logger.info('success')
  225. except Exception as e:
  226. logger.error(f'failure: {e}')
  227. failure_count += 1
  228. failure_times.append(elapsed())
  229. failures.append(str(e))
  230. # import traceback
  231. # traceback.print_exc()
  232. output_file.write(json.dumps({**output_kwargs, **dict(
  233. model=model,
  234. server_name=server_name,
  235. model_id=model_id,
  236. test=test_name,
  237. temp=t,
  238. top_p=top_p,
  239. top_k=top_k,
  240. ctk=ctk,
  241. ctv=ctv,
  242. seed=seed,
  243. success_ratio=float(success_count) / n,
  244. avg_time=mean(success_times + failure_times),
  245. median_time=median(success_times + failure_times),
  246. success_count=success_count,
  247. success_times=success_times,
  248. failure_count=failure_count,
  249. failure_times=failure_times,
  250. failures=list(set(failures)),
  251. )}) + '\n')
  252. output_file.flush()
  253. for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
  254. if hf is not None:
  255. servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
  256. if llama_baseline is not None:
  257. servers.append(('llama-server (baseline)', llama_baseline))
  258. for server_name, server_path in servers:
  259. server = ServerProcess()
  260. server.n_ctx = n_ctx
  261. server.n_slots = 1
  262. server.jinja = True
  263. server.ctk = ctk
  264. server.ctv = ctv
  265. server.fa = fa
  266. server.n_predict = n_predict
  267. server.model_hf_repo = hf
  268. server.model_hf_file = None
  269. server.chat_template = chat_template
  270. server.server_path = server_path
  271. if port is not None:
  272. server.server_port = port
  273. # server.debug = True
  274. with scoped_server(server):
  275. server.start(timeout_seconds=TIMEOUT_SERVER_START)
  276. for ignore_chat_grammar in [False]:
  277. run(
  278. server,
  279. server_name=server_name,
  280. model_id=hf,
  281. temp=t,
  282. output_kwargs=dict(
  283. chat_template=chat_template,
  284. ),
  285. request_kwargs=dict(
  286. ignore_chat_grammar=ignore_chat_grammar,
  287. ),
  288. )
  289. if ollama is not None:
  290. server = ServerProcess()
  291. server.server_port = 11434
  292. server.server_host = "localhost"
  293. subprocess.check_call(["ollama", "pull", ollama])
  294. with scoped_server(server):
  295. run(
  296. server,
  297. server_name="ollama",
  298. model_id=ollama,
  299. temp=t,
  300. output_kwargs=dict(
  301. chat_template=None,
  302. ),
  303. request_kwargs=dict(
  304. model=ollama,
  305. max_tokens=n_predict,
  306. num_ctx = n_ctx,
  307. ),
  308. )
  309. if __name__ == "__main__":
  310. app()