compare-llama-bench.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094
  1. #!/usr/bin/env python3
  2. import argparse
  3. import csv
  4. import heapq
  5. import json
  6. import logging
  7. import os
  8. import sqlite3
  9. import sys
  10. from collections.abc import Iterator, Sequence
  11. from glob import glob
  12. from typing import Any, Optional, Union
  13. try:
  14. import git
  15. from tabulate import tabulate
  16. except ImportError as e:
  17. print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
  18. raise e
  19. logger = logging.getLogger("compare-llama-bench")
  20. # All llama-bench SQL fields
  21. LLAMA_BENCH_DB_FIELDS = [
  22. "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
  23. "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
  24. "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
  25. "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
  26. "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
  27. "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
  28. ]
  29. LLAMA_BENCH_DB_TYPES = [
  30. "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT",
  31. "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
  32. "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
  33. "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
  34. "REAL",
  35. "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
  36. "TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
  37. ]
  38. # All test-backend-ops SQL fields
  39. TEST_BACKEND_OPS_DB_FIELDS = [
  40. "test_time", "build_commit", "backend_name", "op_name", "op_params", "test_mode",
  41. "supported", "passed", "error_message", "time_us", "flops", "bandwidth_gb_s",
  42. "memory_kb", "n_runs"
  43. ]
  44. TEST_BACKEND_OPS_DB_TYPES = [
  45. "TEXT", "TEXT", "TEXT", "TEXT", "TEXT", "TEXT",
  46. "INTEGER", "INTEGER", "TEXT", "REAL", "REAL", "REAL",
  47. "INTEGER", "INTEGER"
  48. ]
  49. assert len(LLAMA_BENCH_DB_FIELDS) == len(LLAMA_BENCH_DB_TYPES)
  50. assert len(TEST_BACKEND_OPS_DB_FIELDS) == len(TEST_BACKEND_OPS_DB_TYPES)
# Properties by which to differentiate results per commit for llama-bench:
# baseline and compare rows are only paired when all of these are equal. The last three
# (n_prompt, n_gen, n_depth) are left out of the --show help text via [:-3].
LLAMA_BENCH_KEY_PROPERTIES = [
    "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
    "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v",
    "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth"
]
# Properties by which to differentiate results per commit for test-backend-ops
# (baseline and compare rows are only paired when all of these are equal):
TEST_BACKEND_OPS_KEY_PROPERTIES = [
    "backend_name", "op_name", "op_params", "test_mode"
]
# Properties that are boolean and are converted to Yes/No for the table:
LLAMA_BENCH_BOOL_PROPERTIES = ["embeddings", "cpu_strict", "use_mmap", "no_kv_offload", "flash_attn"]
TEST_BACKEND_OPS_BOOL_PROPERTIES = ["supported", "passed"]
# Header names for the table (llama-bench), keyed by DB field name:
LLAMA_BENCH_PRETTY_NAMES = {
    "cpu_info": "CPU", "gpu_info": "GPU", "backends": "Backends", "n_gpu_layers": "GPU layers",
    "tensor_buft_overrides": "Tensor overrides", "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]",
    "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", "embeddings": "Embeddings",
    "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", "n_threads": "Threads", "type_k": "K type", "type_v": "V type",
    "use_mmap": "Use mmap", "no_kv_offload": "NKVO", "split_mode": "Split mode", "main_gpu": "Main GPU", "tensor_split": "Tensor split",
    "flash_attn": "FlashAttention",
}
# Header names for the table (test-backend-ops), keyed by DB field name:
TEST_BACKEND_OPS_PRETTY_NAMES = {
    "backend_name": "Backend", "op_name": "GGML op", "op_params": "Op parameters", "test_mode": "Mode",
    "supported": "Supported", "passed": "Passed", "error_message": "Error",
    "flops": "FLOPS", "bandwidth_gb_s": "Bandwidth (GB/s)", "memory_kb": "Memory (KB)", "n_runs": "Runs"
}
DEFAULT_SHOW_LLAMA_BENCH = ["model_type"]  # Always show these properties by default.
DEFAULT_HIDE_LLAMA_BENCH = ["model_filename"]  # Always hide these properties by default.
DEFAULT_SHOW_TEST_BACKEND_OPS = ["backend_name", "op_name"]  # Always show these properties by default.
DEFAULT_HIDE_TEST_BACKEND_OPS = ["error_message"]  # Always hide these properties by default.
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "]  # Strip prefixes for smaller tables.
# NOTE(review): presumably shortens verbose model_type suffixes for the table; the code that
# applies this mapping is outside this view — confirm.
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
# argparse description; the parser uses RawDescriptionHelpFormatter, so line breaks are kept verbatim.
DESCRIPTION = """Creates tables from llama-bench or test-backend-ops data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
For llama-bench:
$ git checkout master
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t llama-bench -j $(nproc)
$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
$ git checkout some_branch
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t llama-bench -j $(nproc)
$ ./llama-bench -o sql | sqlite3 llama-bench.sqlite
$ ./scripts/compare-llama-bench.py
For test-backend-ops:
$ git checkout master
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t test-backend-ops -j $(nproc)
$ ./test-backend-ops perf --output sql | sqlite3 test-backend-ops.sqlite
$ git checkout some_branch
$ cmake -B ${BUILD_DIR} ${CMAKE_OPTS} && cmake --build ${BUILD_DIR} -t test-backend-ops -j $(nproc)
$ ./test-backend-ops perf --output sql | sqlite3 test-backend-ops.sqlite
$ ./scripts/compare-llama-bench.py --tool test-backend-ops -i test-backend-ops.sqlite
Performance numbers from multiple runs per commit are averaged WITHOUT being weighted by the --repetitions parameter of llama-bench.
"""
  104. parser = argparse.ArgumentParser(
  105. description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter)
  106. help_b = (
  107. "The baseline commit to compare performance to. "
  108. "Accepts either a branch name, tag name, or commit hash. "
  109. "Defaults to latest master commit with data."
  110. )
  111. parser.add_argument("-b", "--baseline", help=help_b)
  112. help_c = (
  113. "The commit whose performance is to be compared to the baseline. "
  114. "Accepts either a branch name, tag name, or commit hash. "
  115. "Defaults to the non-master commit for which llama-bench was run most recently."
  116. )
  117. parser.add_argument("-c", "--compare", help=help_c)
  118. help_t = (
  119. "The tool whose data is being compared. "
  120. "Either 'llama-bench' or 'test-backend-ops'. "
  121. "This determines the database schema and comparison logic used. "
  122. "If left unspecified, try to determine from the input file."
  123. )
  124. parser.add_argument("-t", "--tool", help=help_t, default=None, choices=[None, "llama-bench", "test-backend-ops"])
  125. help_i = (
  126. "JSON/JSONL/SQLite/CSV files for comparing commits. "
  127. "Specify multiple times to use multiple input files (JSON/CSV only). "
  128. "Defaults to 'llama-bench.sqlite' in the current working directory. "
  129. "If no such file is found and there is exactly one .sqlite file in the current directory, "
  130. "that file is instead used as input."
  131. )
  132. parser.add_argument("-i", "--input", action="append", help=help_i)
  133. help_o = (
  134. "Output format for the table. "
  135. "Defaults to 'pipe' (GitHub compatible). "
  136. "Also supports e.g. 'latex' or 'mediawiki'. "
  137. "See tabulate documentation for full list."
  138. )
  139. parser.add_argument("-o", "--output", help=help_o, default="pipe")
  140. help_s = (
  141. "Columns to add to the table. "
  142. "Accepts a comma-separated list of values. "
  143. f"Legal values for test-backend-ops: {', '.join(TEST_BACKEND_OPS_KEY_PROPERTIES)}. "
  144. f"Legal values for llama-bench: {', '.join(LLAMA_BENCH_KEY_PROPERTIES[:-3])}. "
  145. "Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
  146. "plus any column where not all data points are the same. "
  147. "If the columns are manually specified, then the results for each unique combination of the "
  148. "specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench."
  149. )
  150. parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
  151. parser.add_argument("-s", "--show", help=help_s)
  152. parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
  153. parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
  154. parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
  155. parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")
  156. known_args, unknown_args = parser.parse_known_args()
  157. logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
  158. if known_args.check:
  159. # Check if all required Python libraries are installed. Would have failed earlier if not.
  160. sys.exit(0)
  161. if unknown_args:
  162. logger.error(f"Received unknown args: {unknown_args}.\n")
  163. parser.print_help()
  164. sys.exit(1)
  165. input_file = known_args.input
  166. tool = known_args.tool
  167. if not input_file:
  168. if tool == "llama-bench" and os.path.exists("./llama-bench.sqlite"):
  169. input_file = ["llama-bench.sqlite"]
  170. elif tool == "test-backend-ops" and os.path.exists("./test-backend-ops.sqlite"):
  171. input_file = ["test-backend-ops.sqlite"]
  172. if not input_file:
  173. sqlite_files = glob("*.sqlite")
  174. if len(sqlite_files) == 1:
  175. input_file = sqlite_files
  176. if not input_file:
  177. logger.error("Cannot find a suitable input file, please provide one.\n")
  178. parser.print_help()
  179. sys.exit(1)
  180. class LlamaBenchData:
  181. repo: Optional[git.Repo]
  182. build_len_min: int
  183. build_len_max: int
  184. build_len: int = 8
  185. builds: list[str] = []
  186. tool: str = "llama-bench" # Tool type: "llama-bench" or "test-backend-ops"
  187. def __init__(self, tool: str = "llama-bench"):
  188. self.tool = tool
  189. try:
  190. self.repo = git.Repo(".", search_parent_directories=True)
  191. except git.InvalidGitRepositoryError:
  192. self.repo = None
  193. # Set schema-specific properties based on tool
  194. if self.tool == "llama-bench":
  195. self.check_keys = set(LLAMA_BENCH_KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
  196. elif self.tool == "test-backend-ops":
  197. self.check_keys = set(TEST_BACKEND_OPS_KEY_PROPERTIES + ["build_commit", "test_time"])
  198. else:
  199. assert False
  200. def _builds_init(self):
  201. self.build_len = self.build_len_min
  202. def _check_keys(self, keys: set) -> Optional[set]:
  203. """Private helper method that checks against required data keys and returns missing ones."""
  204. if not keys >= self.check_keys:
  205. return self.check_keys - keys
  206. return None
  207. def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
  208. """Helper method to find the most recent parent measured in number of commits for which there is data."""
  209. heap: list[tuple[int, git.Commit]] = [(0, commit)]
  210. seen_hexsha8 = set()
  211. while heap:
  212. depth, current_commit = heapq.heappop(heap)
  213. current_hexsha8 = commit.hexsha[:self.build_len]
  214. if current_hexsha8 in self.builds:
  215. return current_hexsha8
  216. for parent in commit.parents:
  217. parent_hexsha8 = parent.hexsha[:self.build_len]
  218. if parent_hexsha8 not in seen_hexsha8:
  219. seen_hexsha8.add(parent_hexsha8)
  220. heapq.heappush(heap, (depth + 1, parent))
  221. return None
  222. def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
  223. """Helper method to recursively get hexsha8 values for all parents of a commit."""
  224. unvisited = [commit]
  225. visited = []
  226. while unvisited:
  227. current_commit = unvisited.pop(0)
  228. visited.append(current_commit.hexsha[:self.build_len])
  229. for parent in current_commit.parents:
  230. if parent.hexsha[:self.build_len] not in visited:
  231. unvisited.append(parent)
  232. return visited
  233. def get_commit_name(self, hexsha8: str) -> str:
  234. """Helper method to find a human-readable name for a commit if possible."""
  235. if self.repo is None:
  236. return hexsha8
  237. for h in self.repo.heads:
  238. if h.commit.hexsha[:self.build_len] == hexsha8:
  239. return h.name
  240. for t in self.repo.tags:
  241. if t.commit.hexsha[:self.build_len] == hexsha8:
  242. return t.name
  243. return hexsha8
  244. def get_commit_hexsha8(self, name: str) -> Optional[str]:
  245. """Helper method to search for a commit given a human-readable name."""
  246. if self.repo is None:
  247. return None
  248. for h in self.repo.heads:
  249. if h.name == name:
  250. return h.commit.hexsha[:self.build_len]
  251. for t in self.repo.tags:
  252. if t.name == name:
  253. return t.commit.hexsha[:self.build_len]
  254. for c in self.repo.iter_commits("--all"):
  255. if c.hexsha[:self.build_len] == name[:self.build_len]:
  256. return c.hexsha[:self.build_len]
  257. return None
  258. def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
  259. """Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
  260. return []
  261. def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
  262. """
  263. Helper method that gets table rows for some list of properties.
  264. Rows are created by combining those where all provided properties are equal.
  265. The resulting rows are then grouped by the provided properties and the t/s values are averaged.
  266. The returned rows are unique in terms of property combinations.
  267. """
  268. return []
  269. class LlamaBenchDataSQLite3(LlamaBenchData):
  270. connection: Optional[sqlite3.Connection] = None
  271. cursor: sqlite3.Cursor
  272. table_name: str
  273. def __init__(self, tool: str = "llama-bench"):
  274. super().__init__(tool)
  275. if self.connection is None:
  276. self.connection = sqlite3.connect(":memory:")
  277. self.cursor = self.connection.cursor()
  278. # Set table name and schema based on tool
  279. if self.tool == "llama-bench":
  280. self.table_name = "llama_bench"
  281. db_fields = LLAMA_BENCH_DB_FIELDS
  282. db_types = LLAMA_BENCH_DB_TYPES
  283. elif self.tool == "test-backend-ops":
  284. self.table_name = "test_backend_ops"
  285. db_fields = TEST_BACKEND_OPS_DB_FIELDS
  286. db_types = TEST_BACKEND_OPS_DB_TYPES
  287. else:
  288. assert False
  289. self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")
  290. def _builds_init(self):
  291. if self.connection:
  292. self.build_len_min = self.cursor.execute(f"SELECT MIN(LENGTH(build_commit)) from {self.table_name};").fetchone()[0]
  293. self.build_len_max = self.cursor.execute(f"SELECT MAX(LENGTH(build_commit)) from {self.table_name};").fetchone()[0]
  294. if self.build_len_min != self.build_len_max:
  295. logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
  296. "Try purging the the database of old commits.")
  297. self.cursor.execute(f"UPDATE {self.table_name} SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
  298. builds = self.cursor.execute(f"SELECT DISTINCT build_commit FROM {self.table_name};").fetchall()
  299. self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
  300. super()._builds_init()
  301. def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
  302. data = self.cursor.execute(
  303. f"SELECT build_commit, test_time FROM {self.table_name} ORDER BY test_time;").fetchall()
  304. return reversed(data) if reverse else data
  305. def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
  306. if self.tool == "llama-bench":
  307. return self._get_rows_llama_bench(properties, hexsha8_baseline, hexsha8_compare)
  308. elif self.tool == "test-backend-ops":
  309. return self._get_rows_test_backend_ops(properties, hexsha8_baseline, hexsha8_compare)
  310. else:
  311. assert False
  312. def _get_rows_llama_bench(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
  313. select_string = ", ".join(
  314. [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
  315. equal_string = " AND ".join(
  316. [f"tb.{p} = tc.{p}" for p in LLAMA_BENCH_KEY_PROPERTIES] + [
  317. f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
  318. )
  319. group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
  320. query = (f"SELECT {select_string} FROM {self.table_name} tb JOIN {self.table_name} tc ON {equal_string} "
  321. f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
  322. return self.cursor.execute(query).fetchall()
  323. def _get_rows_test_backend_ops(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
  324. # For test-backend-ops, we compare FLOPS and bandwidth metrics (prioritizing FLOPS over bandwidth)
  325. select_string = ", ".join(
  326. [f"tb.{p}" for p in properties] + [
  327. "AVG(tb.flops)", "AVG(tc.flops)",
  328. "AVG(tb.bandwidth_gb_s)", "AVG(tc.bandwidth_gb_s)"
  329. ])
  330. equal_string = " AND ".join(
  331. [f"tb.{p} = tc.{p}" for p in TEST_BACKEND_OPS_KEY_PROPERTIES] + [
  332. f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'",
  333. "tb.supported = 1", "tc.supported = 1", "tb.passed = 1", "tc.passed = 1"] # Only compare successful tests
  334. )
  335. group_order_string = ", ".join([f"tb.{p}" for p in properties])
  336. query = (f"SELECT {select_string} FROM {self.table_name} tb JOIN {self.table_name} tc ON {equal_string} "
  337. f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
  338. return self.cursor.execute(query).fetchall()
  339. class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
  340. def __init__(self, data_file: str, tool: Any):
  341. self.connection = sqlite3.connect(data_file)
  342. self.cursor = self.connection.cursor()
  343. # Check which table exists in the database
  344. tables = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
  345. table_names = [table[0] for table in tables]
  346. # Tool selection logic
  347. if tool is None:
  348. if "llama_bench" in table_names:
  349. self.table_name = "llama_bench"
  350. tool = "llama-bench"
  351. elif "test_backend_ops" in table_names:
  352. self.table_name = "test_backend_ops"
  353. tool = "test-backend-ops"
  354. else:
  355. raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}")
  356. elif tool == "llama-bench":
  357. if "llama_bench" in table_names:
  358. self.table_name = "llama_bench"
  359. tool = "llama-bench"
  360. else:
  361. raise RuntimeError(f"Table 'test' not found for tool 'llama-bench'. Available tables: {table_names}")
  362. elif tool == "test-backend-ops":
  363. if "test_backend_ops" in table_names:
  364. self.table_name = "test_backend_ops"
  365. tool = "test-backend-ops"
  366. else:
  367. raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}")
  368. else:
  369. raise RuntimeError(f"Unknown tool: {tool}")
  370. super().__init__(tool)
  371. self._builds_init()
  372. @staticmethod
  373. def valid_format(data_file: str) -> bool:
  374. connection = sqlite3.connect(data_file)
  375. cursor = connection.cursor()
  376. try:
  377. if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
  378. raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
  379. except sqlite3.DatabaseError as e:
  380. logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
  381. cursor = None
  382. connection.close()
  383. return True if cursor else False
  384. class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
  385. def __init__(self, data_file: str, tool: str = "llama-bench"):
  386. super().__init__(tool)
  387. # Get the appropriate field list based on tool
  388. db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS
  389. with open(data_file, "r", encoding="utf-8") as fp:
  390. for i, line in enumerate(fp):
  391. parsed = json.loads(line)
  392. for k in parsed.keys() - set(db_fields):
  393. del parsed[k]
  394. if (missing_keys := self._check_keys(parsed.keys())):
  395. raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
  396. self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
  397. self._builds_init()
  398. @staticmethod
  399. def valid_format(data_file: str) -> bool:
  400. try:
  401. with open(data_file, "r", encoding="utf-8") as fp:
  402. for line in fp:
  403. json.loads(line)
  404. break
  405. except Exception as e:
  406. logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
  407. return False
  408. return True
  409. class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
  410. def __init__(self, data_files: list[str], tool: str = "llama-bench"):
  411. super().__init__(tool)
  412. # Get the appropriate field list based on tool
  413. db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS
  414. for data_file in data_files:
  415. with open(data_file, "r", encoding="utf-8") as fp:
  416. parsed = json.load(fp)
  417. for i, entry in enumerate(parsed):
  418. for k in entry.keys() - set(db_fields):
  419. del entry[k]
  420. if (missing_keys := self._check_keys(entry.keys())):
  421. raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
  422. self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
  423. self._builds_init()
  424. @staticmethod
  425. def valid_format(data_files: list[str]) -> bool:
  426. if not data_files:
  427. return False
  428. for data_file in data_files:
  429. try:
  430. with open(data_file, "r", encoding="utf-8") as fp:
  431. json.load(fp)
  432. except Exception as e:
  433. logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
  434. return False
  435. return True
  436. class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
  437. def __init__(self, data_files: list[str], tool: str = "llama-bench"):
  438. super().__init__(tool)
  439. # Get the appropriate field list based on tool
  440. db_fields = LLAMA_BENCH_DB_FIELDS if tool == "llama-bench" else TEST_BACKEND_OPS_DB_FIELDS
  441. for data_file in data_files:
  442. with open(data_file, "r", encoding="utf-8") as fp:
  443. for i, parsed in enumerate(csv.DictReader(fp)):
  444. keys = set(parsed.keys())
  445. for k in keys - set(db_fields):
  446. del parsed[k]
  447. if (missing_keys := self._check_keys(keys)):
  448. raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
  449. self.cursor.execute(f"INSERT INTO {self.table_name}({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
  450. self._builds_init()
  451. @staticmethod
  452. def valid_format(data_files: list[str]) -> bool:
  453. if not data_files:
  454. return False
  455. for data_file in data_files:
  456. try:
  457. with open(data_file, "r", encoding="utf-8") as fp:
  458. for parsed in csv.DictReader(fp):
  459. break
  460. except Exception as e:
  461. logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
  462. return False
  463. return True
  464. def format_flops(flops_value: float) -> str:
  465. """Format FLOPS values with appropriate units for better readability."""
  466. if flops_value == 0:
  467. return "0.00"
  468. # Define unit thresholds and names
  469. units = [
  470. (1e12, "T"), # TeraFLOPS
  471. (1e9, "G"), # GigaFLOPS
  472. (1e6, "M"), # MegaFLOPS
  473. (1e3, "k"), # kiloFLOPS
  474. (1, "") # FLOPS
  475. ]
  476. for threshold, unit in units:
  477. if abs(flops_value) >= threshold:
  478. formatted_value = flops_value / threshold
  479. if formatted_value >= 100:
  480. return f"{formatted_value:.1f}{unit}"
  481. else:
  482. return f"{formatted_value:.2f}{unit}"
  483. # Fallback for very small values
  484. return f"{flops_value:.2f}"
  485. def format_flops_for_table(flops_value: float, target_unit: str) -> str:
  486. """Format FLOPS values for table display without unit suffix (since unit is in header)."""
  487. if flops_value == 0:
  488. return "0.00"
  489. # Define unit thresholds based on target unit
  490. unit_divisors = {
  491. "TFLOPS": 1e12,
  492. "GFLOPS": 1e9,
  493. "MFLOPS": 1e6,
  494. "kFLOPS": 1e3,
  495. "FLOPS": 1
  496. }
  497. divisor = unit_divisors.get(target_unit, 1)
  498. formatted_value = flops_value / divisor
  499. if formatted_value >= 100:
  500. return f"{formatted_value:.1f}"
  501. else:
  502. return f"{formatted_value:.2f}"
  503. def get_flops_unit_name(flops_values: list) -> str:
  504. """Determine the best FLOPS unit name based on the magnitude of values."""
  505. if not flops_values or all(v == 0 for v in flops_values):
  506. return "FLOPS"
  507. # Find the maximum absolute value to determine appropriate unit
  508. max_flops = max(abs(v) for v in flops_values if v != 0)
  509. if max_flops >= 1e12:
  510. return "TFLOPS"
  511. elif max_flops >= 1e9:
  512. return "GFLOPS"
  513. elif max_flops >= 1e6:
  514. return "MFLOPS"
  515. elif max_flops >= 1e3:
  516. return "kFLOPS"
  517. else:
  518. return "FLOPS"
  519. bench_data = None
  520. if len(input_file) == 1:
  521. if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
  522. bench_data = LlamaBenchDataSQLite3File(input_file[0], tool)
  523. elif LlamaBenchDataJSON.valid_format(input_file):
  524. bench_data = LlamaBenchDataJSON(input_file, tool)
  525. elif LlamaBenchDataJSONL.valid_format(input_file[0]):
  526. bench_data = LlamaBenchDataJSONL(input_file[0], tool)
  527. elif LlamaBenchDataCSV.valid_format(input_file):
  528. bench_data = LlamaBenchDataCSV(input_file, tool)
  529. else:
  530. if LlamaBenchDataJSON.valid_format(input_file):
  531. bench_data = LlamaBenchDataJSON(input_file, tool)
  532. elif LlamaBenchDataCSV.valid_format(input_file):
  533. bench_data = LlamaBenchDataCSV(input_file, tool)
  534. if not bench_data:
  535. raise RuntimeError("No valid (or some invalid) input files found.")
  536. if not bench_data.builds:
  537. raise RuntimeError(f"{input_file} does not contain any builds.")
  538. tool = bench_data.tool # May have chosen a default if tool was None.
  539. hexsha8_baseline = name_baseline = None
  540. # If the user specified a baseline, try to find a commit for it:
  541. if known_args.baseline is not None:
  542. if known_args.baseline in bench_data.builds:
  543. hexsha8_baseline = known_args.baseline
  544. if hexsha8_baseline is None:
  545. hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
  546. name_baseline = known_args.baseline
  547. if hexsha8_baseline is None:
  548. logger.error(f"cannot find data for baseline={known_args.baseline}.")
  549. sys.exit(1)
  550. # Otherwise, search for the most recent parent of master for which there is data:
  551. elif bench_data.repo is not None:
  552. hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)
  553. if hexsha8_baseline is None:
  554. logger.error("No baseline was provided and did not find data for any master branch commits.\n")
  555. parser.print_help()
  556. sys.exit(1)
  557. else:
  558. logger.error("No baseline was provided and the current working directory "
  559. "is not part of a git repository from which a baseline could be inferred.\n")
  560. parser.print_help()
  561. sys.exit(1)
  562. name_baseline = bench_data.get_commit_name(hexsha8_baseline)
  563. hexsha8_compare = name_compare = None
  564. # If the user has specified a compare value, try to find a corresponding commit:
  565. if known_args.compare is not None:
  566. if known_args.compare in bench_data.builds:
  567. hexsha8_compare = known_args.compare
  568. if hexsha8_compare is None:
  569. hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
  570. name_compare = known_args.compare
  571. if hexsha8_compare is None:
  572. logger.error(f"cannot find data for compare={known_args.compare}.")
  573. sys.exit(1)
  574. # Otherwise, search for the commit for llama-bench was most recently run
  575. # and that is not a parent of master:
  576. elif bench_data.repo is not None:
  577. hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
  578. for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
  579. if hexsha8 not in hexsha8s_master:
  580. hexsha8_compare = hexsha8
  581. break
  582. if hexsha8_compare is None:
  583. logger.error("No compare target was provided and did not find data for any non-master commits.\n")
  584. parser.print_help()
  585. sys.exit(1)
  586. else:
  587. logger.error("No compare target was provided and the current working directory "
  588. "is not part of a git repository from which a compare target could be inferred.\n")
  589. parser.print_help()
  590. sys.exit(1)
  591. name_compare = bench_data.get_commit_name(hexsha8_compare)
  592. # Get tool-specific configuration
  593. if tool == "llama-bench":
  594. key_properties = LLAMA_BENCH_KEY_PROPERTIES
  595. bool_properties = LLAMA_BENCH_BOOL_PROPERTIES
  596. pretty_names = LLAMA_BENCH_PRETTY_NAMES
  597. default_show = DEFAULT_SHOW_LLAMA_BENCH
  598. default_hide = DEFAULT_HIDE_LLAMA_BENCH
  599. elif tool == "test-backend-ops":
  600. key_properties = TEST_BACKEND_OPS_KEY_PROPERTIES
  601. bool_properties = TEST_BACKEND_OPS_BOOL_PROPERTIES
  602. pretty_names = TEST_BACKEND_OPS_PRETTY_NAMES
  603. default_show = DEFAULT_SHOW_TEST_BACKEND_OPS
  604. default_hide = DEFAULT_HIDE_TEST_BACKEND_OPS
  605. else:
  606. assert False
  607. # If the user provided columns to group the results by, use them:
  608. if known_args.show is not None:
  609. show = known_args.show.split(",")
  610. unknown_cols = []
  611. for prop in show:
  612. valid_props = key_properties if tool == "test-backend-ops" else key_properties[:-3] # Exclude n_prompt, n_gen, n_depth for llama-bench
  613. if prop not in valid_props:
  614. unknown_cols.append(prop)
  615. if unknown_cols:
  616. logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
  617. parser.print_usage()
  618. sys.exit(1)
  619. rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
  620. # Otherwise, select those columns where the values are not all the same:
  621. else:
  622. rows_full = bench_data.get_rows(key_properties, hexsha8_baseline, hexsha8_compare)
  623. properties_different = []
  624. if tool == "llama-bench":
  625. # For llama-bench, skip n_prompt, n_gen, n_depth from differentiation logic
  626. check_properties = [kp for kp in key_properties if kp not in ["n_prompt", "n_gen", "n_depth"]]
  627. for i, kp_i in enumerate(key_properties):
  628. if kp_i in default_show or kp_i in ["n_prompt", "n_gen", "n_depth"]:
  629. continue
  630. for row_full in rows_full:
  631. if row_full[i] != rows_full[0][i]:
  632. properties_different.append(kp_i)
  633. break
  634. elif tool == "test-backend-ops":
  635. # For test-backend-ops, check all key properties
  636. for i, kp_i in enumerate(key_properties):
  637. if kp_i in default_show:
  638. continue
  639. for row_full in rows_full:
  640. if row_full[i] != rows_full[0][i]:
  641. properties_different.append(kp_i)
  642. break
  643. else:
  644. assert False
  645. show = []
  646. if tool == "llama-bench":
  647. # Show CPU and/or GPU by default even if the hardware for all results is the same:
  648. if rows_full and "n_gpu_layers" not in properties_different:
  649. ngl = int(rows_full[0][key_properties.index("n_gpu_layers")])
  650. if ngl != 99 and "cpu_info" not in properties_different:
  651. show.append("cpu_info")
  652. show += properties_different
  653. index_default = 0
  654. for prop in ["cpu_info", "gpu_info", "n_gpu_layers", "main_gpu"]:
  655. if prop in show:
  656. index_default += 1
  657. show = show[:index_default] + default_show + show[index_default:]
  658. elif tool == "test-backend-ops":
  659. show = default_show + properties_different
  660. else:
  661. assert False
  662. for prop in default_hide:
  663. try:
  664. show.remove(prop)
  665. except ValueError:
  666. pass
  667. # Add plot_x parameter to parameters to show if it's not already present:
  668. if known_args.plot:
  669. for k, v in pretty_names.items():
  670. if v == known_args.plot_x and k not in show:
  671. show.append(k)
  672. break
  673. rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
  674. if not rows_show:
  675. logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")
  676. sys.exit(1)
  677. table = []
  678. primary_metric = "FLOPS" # Default to FLOPS for test-backend-ops
  679. if tool == "llama-bench":
  680. # For llama-bench, create test names and compare avg_ts values
  681. for row in rows_show:
  682. n_prompt = int(row[-5])
  683. n_gen = int(row[-4])
  684. n_depth = int(row[-3])
  685. if n_prompt != 0 and n_gen == 0:
  686. test_name = f"pp{n_prompt}"
  687. elif n_prompt == 0 and n_gen != 0:
  688. test_name = f"tg{n_gen}"
  689. else:
  690. test_name = f"pp{n_prompt}+tg{n_gen}"
  691. if n_depth != 0:
  692. test_name = f"{test_name}@d{n_depth}"
  693. # Regular columns test name avg t/s values Speedup
  694. # VVVVVVVVVVVVV VVVVVVVVV VVVVVVVVVVVVVV VVVVVVV
  695. table.append(list(row[:-5]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])])
  696. elif tool == "test-backend-ops":
  697. # Determine the primary metric by checking rows until we find one with valid data
  698. if rows_show:
  699. primary_metric = "FLOPS" # Default to FLOPS
  700. flops_values = []
  701. # Collect all FLOPS values to determine the best unit
  702. for sample_row in rows_show:
  703. baseline_flops = float(sample_row[-4])
  704. compare_flops = float(sample_row[-3])
  705. baseline_bandwidth = float(sample_row[-2])
  706. if baseline_flops > 0:
  707. flops_values.extend([baseline_flops, compare_flops])
  708. elif baseline_bandwidth > 0 and not flops_values:
  709. primary_metric = "Bandwidth (GB/s)"
  710. # If we have FLOPS data, determine the appropriate unit
  711. if flops_values:
  712. primary_metric = get_flops_unit_name(flops_values)
  713. # For test-backend-ops, prioritize FLOPS > bandwidth for comparison
  714. for row in rows_show:
  715. # Extract metrics: flops, bandwidth_gb_s (baseline and compare)
  716. baseline_flops = float(row[-4])
  717. compare_flops = float(row[-3])
  718. baseline_bandwidth = float(row[-2])
  719. compare_bandwidth = float(row[-1])
  720. # Determine which metric to use for comparison (prioritize FLOPS > bandwidth)
  721. if baseline_flops > 0 and compare_flops > 0:
  722. # Use FLOPS comparison (higher is better)
  723. speedup = compare_flops / baseline_flops
  724. baseline_str = format_flops_for_table(baseline_flops, primary_metric)
  725. compare_str = format_flops_for_table(compare_flops, primary_metric)
  726. elif baseline_bandwidth > 0 and compare_bandwidth > 0:
  727. # Use bandwidth comparison (higher is better)
  728. speedup = compare_bandwidth / baseline_bandwidth
  729. baseline_str = f"{baseline_bandwidth:.2f}"
  730. compare_str = f"{compare_bandwidth:.2f}"
  731. else:
  732. # Fallback if no valid data is available
  733. baseline_str = "N/A"
  734. compare_str = "N/A"
  735. from math import nan
  736. speedup = nan
  737. table.append(list(row[:-4]) + [baseline_str, compare_str, speedup])
  738. else:
  739. assert False
  740. # Some a-posteriori fixes to make the table contents prettier:
  741. for bool_property in bool_properties:
  742. if bool_property in show:
  743. ip = show.index(bool_property)
  744. for row_table in table:
  745. row_table[ip] = "Yes" if int(row_table[ip]) == 1 else "No"
  746. if tool == "llama-bench":
  747. if "model_type" in show:
  748. ip = show.index("model_type")
  749. for (old, new) in MODEL_SUFFIX_REPLACE.items():
  750. for row_table in table:
  751. row_table[ip] = row_table[ip].replace(old, new)
  752. if "model_size" in show:
  753. ip = show.index("model_size")
  754. for row_table in table:
  755. row_table[ip] = float(row_table[ip]) / 1024 ** 3
  756. if "gpu_info" in show:
  757. ip = show.index("gpu_info")
  758. for row_table in table:
  759. for gns in GPU_NAME_STRIP:
  760. row_table[ip] = row_table[ip].replace(gns, "")
  761. gpu_names = row_table[ip].split(", ")
  762. num_gpus = len(gpu_names)
  763. all_names_the_same = len(set(gpu_names)) == 1
  764. if len(gpu_names) >= 2 and all_names_the_same:
  765. row_table[ip] = f"{num_gpus}x {gpu_names[0]}"
  766. headers = [pretty_names.get(p, p) for p in show]
  767. if tool == "llama-bench":
  768. headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
  769. elif tool == "test-backend-ops":
  770. headers += [f"{primary_metric} {name_baseline}", f"{primary_metric} {name_compare}", "Speedup"]
  771. else:
  772. assert False
  773. if known_args.plot:
  774. def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False, tool_type: str = "llama-bench", metric_name: str = "t/s"):
  775. try:
  776. import matplotlib
  777. import matplotlib.pyplot as plt
  778. matplotlib.use('Agg')
  779. except ImportError as e:
  780. logger.error("matplotlib is required for --plot.")
  781. raise e
  782. data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
  783. plot_x_index = None
  784. plot_x_label = plot_x_param
  785. if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
  786. pretty_name = LLAMA_BENCH_PRETTY_NAMES.get(plot_x_param, plot_x_param)
  787. if pretty_name in data_headers:
  788. plot_x_index = data_headers.index(pretty_name)
  789. plot_x_label = pretty_name
  790. elif plot_x_param in data_headers:
  791. plot_x_index = data_headers.index(plot_x_param)
  792. plot_x_label = plot_x_param
  793. else:
  794. logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
  795. return
  796. grouped_data = {}
  797. for i, row in enumerate(table_data):
  798. group_key_parts = []
  799. test_name = row[-4]
  800. base_test = ""
  801. x_value = None
  802. if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
  803. for j, val in enumerate(row[:-4]):
  804. header_name = data_headers[j]
  805. if val is not None and str(val).strip():
  806. group_key_parts.append(f"{header_name}={val}")
  807. if plot_x_param == "n_prompt" and "pp" in test_name:
  808. base_test = test_name.split("@")[0]
  809. x_value = base_test
  810. elif plot_x_param == "n_gen" and "tg" in test_name:
  811. x_value = test_name.split("@")[0]
  812. elif plot_x_param == "n_depth" and "@d" in test_name:
  813. base_test = test_name.split("@d")[0]
  814. x_value = int(test_name.split("@d")[1])
  815. else:
  816. base_test = test_name
  817. if base_test.strip():
  818. group_key_parts.append(f"Test={base_test}")
  819. else:
  820. for j, val in enumerate(row[:-4]):
  821. if j != plot_x_index:
  822. header_name = data_headers[j]
  823. if val is not None and str(val).strip():
  824. group_key_parts.append(f"{header_name}={val}")
  825. else:
  826. x_value = val
  827. group_key_parts.append(f"Test={test_name}")
  828. group_key = tuple(group_key_parts)
  829. if group_key not in grouped_data:
  830. grouped_data[group_key] = []
  831. grouped_data[group_key].append({
  832. 'x_value': x_value,
  833. 'baseline': float(row[-3]),
  834. 'compare': float(row[-2]),
  835. 'speedup': float(row[-1])
  836. })
  837. if not grouped_data:
  838. logger.error("No data available for plotting")
  839. return
  840. def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
  841. from math import ceil
  842. cols = 1 if num_groups == 1 else min(max_cols, num_groups)
  843. rows = ceil(num_groups / cols)
  844. # Scale figure size by grid dimensions
  845. w, h = base_size
  846. fig, ax_arr = plt.subplots(rows, cols,
  847. figsize=(w * cols, h * rows),
  848. squeeze=False)
  849. axes = ax_arr.flatten()[:num_groups]
  850. return fig, axes
  851. num_groups = len(grouped_data)
  852. fig, axes = make_axes(num_groups)
  853. plot_idx = 0
  854. for group_key, points in grouped_data.items():
  855. if plot_idx >= len(axes):
  856. break
  857. ax = axes[plot_idx]
  858. try:
  859. points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
  860. x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
  861. except ValueError:
  862. points_sorted = sorted(points, key=lambda p: group_key)
  863. x_values = [p['x_value'] for p in points_sorted]
  864. baseline_vals = [p['baseline'] for p in points_sorted]
  865. compare_vals = [p['compare'] for p in points_sorted]
  866. ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
  867. label=f'{baseline_name}', linewidth=2, markersize=6)
  868. ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
  869. label=f'{compare_name}', linewidth=2, markersize=6)
  870. if log_scale:
  871. ax.set_xscale('log', base=2)
  872. unique_x = sorted(set(x_values))
  873. ax.set_xticks(unique_x)
  874. ax.set_xticklabels([str(int(x)) for x in unique_x])
  875. title_parts = []
  876. for part in group_key:
  877. if '=' in part:
  878. key, value = part.split('=', 1)
  879. title_parts.append(f"{key}: {value}")
  880. title = ', '.join(title_parts) if title_parts else "Performance comparison"
  881. # Determine y-axis label based on tool type
  882. if tool_type == "llama-bench":
  883. y_label = "Tokens per second (t/s)"
  884. elif tool_type == "test-backend-ops":
  885. y_label = metric_name
  886. else:
  887. assert False
  888. ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
  889. ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
  890. ax.set_title(title, fontsize=12, fontweight='bold')
  891. ax.legend(loc='best', fontsize=10)
  892. ax.grid(True, alpha=0.3)
  893. plot_idx += 1
  894. for i in range(plot_idx, len(axes)):
  895. axes[i].set_visible(False)
  896. fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
  897. fontsize=14, fontweight='bold')
  898. fig.subplots_adjust(top=1)
  899. plt.tight_layout()
  900. plt.savefig(output_file, dpi=300, bbox_inches='tight')
  901. plt.close()
  902. create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale, tool, primary_metric)
  903. print(tabulate( # noqa: NP100
  904. table,
  905. headers=headers,
  906. floatfmt=".2f",
  907. tablefmt=known_args.output
  908. ))