@@ -7,6 +7,10 @@ import sys
 import os
 from glob import glob
 import sqlite3
+import json
+import csv
+from typing import Optional, Union
+from collections.abc import Iterator, Sequence
 
 try:
     import git
@@ -17,6 +21,28 @@ except ImportError as e:
 
 logger = logging.getLogger("compare-llama-bench")
 
+# All llama-bench SQL fields
+DB_FIELDS = [
+    "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
+    "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
+    "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
+    "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
+    "defrag_thold",
+    "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
+    "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
+]
+
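+# SQLite types for the corresponding DB_FIELDS entries, paired by position
+# (the two lists are zipped together when the test table is created below):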
+DB_TYPES = [
+    "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT",
+    "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
+    "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
+    "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
+    "REAL",
+    "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
+    "TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
+]
+assert len(DB_FIELDS) == len(DB_TYPES)
+
 # Properties by which to differentiate results per commit:
 KEY_PROPERTIES = [
     "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
@@ -42,7 +68,7 @@ DEFAULT_HIDE = ["model_filename"]  # Always hide these properties by default.
 GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "]  # Strip prefixes for smaller tables.
 MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
 
-DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
+DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file, or an SQLite database. Example usage (Linux):
 
 $ git checkout master
 $ make clean && make llama-bench
@@ -70,12 +96,13 @@ help_c = (
 )
 parser.add_argument("-c", "--compare", help=help_c)
 help_i = (
-    "Input SQLite file for comparing commits. "
+    "JSON/JSONL/SQLite/CSV files for comparing commits. "
+    "Specify multiple times to use multiple input files (JSON/CSV only). "
     "Defaults to 'llama-bench.sqlite' in the current working directory. "
     "If no such file is found and there is exactly one .sqlite file in the current directory, "
     "that file is instead used as input."
 )
-parser.add_argument("-i", "--input", help=help_i)
+parser.add_argument("-i", "--input", action="append", help=help_i)
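+# With action="append", repeating the flag collects paths into a list:
+# -i a.json -i b.json -> ["a.json", "b.json"]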
 help_o = (
     "Output format for the table. "
     "Defaults to 'pipe' (GitHub compatible). "
@@ -110,119 +137,321 @@ if unknown_args:
     sys.exit(1)
 
 input_file = known_args.input
-if input_file is None and os.path.exists("./llama-bench.sqlite"):
-    input_file = "llama-bench.sqlite"
-if input_file is None:
+if not input_file and os.path.exists("./llama-bench.sqlite"):
+    input_file = ["llama-bench.sqlite"]
+if not input_file:
     sqlite_files = glob("*.sqlite")
     if len(sqlite_files) == 1:
-        input_file = sqlite_files[0]
+        input_file = sqlite_files
 
-if input_file is None:
+if not input_file:
     logger.error("Cannot find a suitable input file, please provide one.\n")
     parser.print_help()
     sys.exit(1)
 
-connection = sqlite3.connect(input_file)
-cursor = connection.cursor()
-
-build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
-build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
+class LlamaBenchData:
+    repo: Optional[git.Repo]
+    build_len_min: int
+    build_len_max: int
+    build_len: int = 8
+    builds: list[str] = []
+    check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
 
-if build_len_min != build_len_max:
-    logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
-                   "Try purging the database of old commits.")
-    cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});")
+    def __init__(self):
+        try:
+            self.repo = git.Repo(".", search_parent_directories=True)
+        except git.InvalidGitRepositoryError:
+            self.repo = None
 
-build_len: int = build_len_min
+    def _builds_init(self):
+        self.build_len = self.build_len_min
 
-builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
-builds = list(map(lambda b: b[0], builds))  # list[tuple[str]] -> list[str]
+    def _check_keys(self, keys: set) -> Optional[set]:
+        """Private helper method that checks against required data keys and returns missing ones."""
+        if not keys >= self.check_keys:
+            return self.check_keys - keys
+        return None
 
-if not builds:
-    raise RuntimeError(f"{input_file} does not contain any builds.")
+    def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
+        """Helper method to find the most recent parent measured in number of commits for which there is data."""
+        heap: list[tuple[int, git.Commit]] = [(0, commit)]
+        seen_hexsha8 = set()
+        while heap:
+            depth, current_commit = heapq.heappop(heap)
+            current_hexsha8 = current_commit.hexsha[:self.build_len]
+            if current_hexsha8 in self.builds:
+                return current_hexsha8
+            for parent in current_commit.parents:
+                parent_hexsha8 = parent.hexsha[:self.build_len]
+                if parent_hexsha8 not in seen_hexsha8:
+                    seen_hexsha8.add(parent_hexsha8)
+                    heapq.heappush(heap, (depth + 1, parent))
+        return None
 
-try:
-    repo = git.Repo(".", search_parent_directories=True)
-except git.InvalidGitRepositoryError:
-    repo = None
-
-
-def find_parent_in_data(commit: git.Commit):
-    """Helper function to find the most recent parent measured in number of commits for which there is data."""
-    heap: list[tuple[int, git.Commit]] = [(0, commit)]
-    seen_hexsha8 = set()
-    while heap:
-        depth, current_commit = heapq.heappop(heap)
-        current_hexsha8 = current_commit.hexsha[:build_len]
-        if current_hexsha8 in builds:
-            return current_hexsha8
-        for parent in current_commit.parents:
-            parent_hexsha8 = parent.hexsha[:build_len]
-            if parent_hexsha8 not in seen_hexsha8:
-                seen_hexsha8.add(parent_hexsha8)
-                heapq.heappush(heap, (depth + 1, parent))
-    return None
-
-
-def get_all_parent_hexsha8s(commit: git.Commit):
-    """Helper function to recursively get hexsha8 values for all parents of a commit."""
-    unvisited = [commit]
-    visited = []
-
-    while unvisited:
-        current_commit = unvisited.pop(0)
-        visited.append(current_commit.hexsha[:build_len])
-        for parent in current_commit.parents:
-            if parent.hexsha[:build_len] not in visited:
-                unvisited.append(parent)
-
-    return visited
-
-
-def get_commit_name(hexsha8: str):
-    """Helper function to find a human-readable name for a commit if possible."""
-    if repo is None:
+    def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
+        """Helper method to recursively get hexsha8 values for all parents of a commit."""
+        unvisited = [commit]
+        visited = []
+
+        while unvisited:
+            current_commit = unvisited.pop(0)
+            visited.append(current_commit.hexsha[:self.build_len])
+            for parent in current_commit.parents:
+                if parent.hexsha[:self.build_len] not in visited:
+                    unvisited.append(parent)
+
+        return visited
+
+    def get_commit_name(self, hexsha8: str) -> str:
+        """Helper method to find a human-readable name for a commit if possible."""
+        if self.repo is None:
+            return hexsha8
+        for h in self.repo.heads:
+            if h.commit.hexsha[:self.build_len] == hexsha8:
+                return h.name
+        for t in self.repo.tags:
+            if t.commit.hexsha[:self.build_len] == hexsha8:
+                return t.name
         return hexsha8
-    for h in repo.heads:
-        if h.commit.hexsha[:build_len] == hexsha8:
-            return h.name
-    for t in repo.tags:
-        if t.commit.hexsha[:build_len] == hexsha8:
-            return t.name
-    return hexsha8
-
-
-def get_commit_hexsha8(name: str):
-    """Helper function to search for a commit given a human-readable name."""
-    if repo is None:
+
+    def get_commit_hexsha8(self, name: str) -> Optional[str]:
+        """Helper method to search for a commit given a human-readable name."""
+        if self.repo is None:
+            return None
+        for h in self.repo.heads:
+            if h.name == name:
+                return h.commit.hexsha[:self.build_len]
+        for t in self.repo.tags:
+            if t.name == name:
+                return t.commit.hexsha[:self.build_len]
+        for c in self.repo.iter_commits("--all"):
+            if c.hexsha[:self.build_len] == name[:self.build_len]:
+                return c.hexsha[:self.build_len]
         return None
-    for h in repo.heads:
-        if h.name == name:
-            return h.commit.hexsha[:build_len]
-    for t in repo.tags:
-        if t.name == name:
-            return t.commit.hexsha[:build_len]
-    for c in repo.iter_commits("--all"):
-        if c.hexsha[:build_len] == name[:build_len]:
-            return c.hexsha[:build_len]
-    return None
+
+    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
+        """Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
+        return []
+
+    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
+        """
+        Helper method that gets table rows for some list of properties.
+        Rows are created by combining those where all provided properties are equal.
+        The resulting rows are then grouped by the provided properties and the t/s values are averaged.
+        The returned rows are unique in terms of property combinations.
+        """
+        return []
+
+
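+# Every input format funnels through SQLite: subclasses either connect to an existing
+# database file or insert parsed JSON/JSONL/CSV rows into the in-memory table created
+# here, so the same SQL in get_rows() serves all formats.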
+class LlamaBenchDataSQLite3(LlamaBenchData):
+    connection: sqlite3.Connection
+    cursor: sqlite3.Cursor
+
+    def __init__(self):
+        super().__init__()
+        self.connection = sqlite3.connect(":memory:")
+        self.cursor = self.connection.cursor()
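+        # Build the schema by pairing each DB_FIELDS name with its DB_TYPES type.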
+        self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});")
+
+    def _builds_init(self):
+        if self.connection:
+            self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
+            self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
+
+            if self.build_len_min != self.build_len_max:
+                logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
+                               "Try purging the database of old commits.")
+                self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
+
+            builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
+            self.builds = list(map(lambda b: b[0], builds))  # list[tuple[str]] -> list[str]
+        super()._builds_init()
+
+    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
+        data = self.cursor.execute(
+            "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
+        return reversed(data) if reverse else data
+
+    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
+        select_string = ", ".join(
+            [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
+        equal_string = " AND ".join(
+            [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
+                f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
+        )
+        group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
+        query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
+                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
+        return self.cursor.execute(query).fetchall()
+
+
+class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
+    def __init__(self, data_file: str):
+        super().__init__()
+
+        self.connection.close()
+        self.connection = sqlite3.connect(data_file)
+        self.cursor = self.connection.cursor()
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_file: str) -> bool:
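+        """Best-effort probe: return True if data_file opens as a non-empty SQLite database."""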
+        connection = sqlite3.connect(data_file)
+        cursor = connection.cursor()
+
+        try:
+            if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
+                raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
+        except sqlite3.DatabaseError as e:
+            logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
+            cursor = None
+
+        connection.close()
+        return cursor is not None
+
+
+class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
+    def __init__(self, data_file: str):
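+        """Load one JSONL file (one JSON object per line) into the in-memory table, dropping unknown keys."""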
+        super().__init__()
+
+        with open(data_file, "r", encoding="utf-8") as fp:
+            for i, line in enumerate(fp):
+                parsed = json.loads(line)
+
+                for k in parsed.keys() - set(DB_FIELDS):
+                    del parsed[k]
+
+                if (missing_keys := self._check_keys(parsed.keys())):
+                    raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
+
+                self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
+
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_file: str) -> bool:
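+        """Best-effort probe: return True if the file's first line parses as JSON."""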
+        try:
+            with open(data_file, "r", encoding="utf-8") as fp:
+                for line in fp:
+                    json.loads(line)
+                    break
+        except Exception as e:
+            logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
+            return False
+
+        return True
+
+
+class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
+    def __init__(self, data_files: list[str]):
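+        """Load benchmark entries from one or more JSON files (each a list of objects) into the in-memory table."""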
+        super().__init__()
+
+        for data_file in data_files:
+            with open(data_file, "r", encoding="utf-8") as fp:
+                parsed = json.load(fp)
+
+                for i, entry in enumerate(parsed):
+                    for k in entry.keys() - set(DB_FIELDS):
+                        del entry[k]
+
+                    if (missing_keys := self._check_keys(entry.keys())):
+                        raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
+
+                    self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
+
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_files: list[str]) -> bool:
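+        """Return True only if every file in data_files parses as a JSON document."""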
+        if not data_files:
+            return False
+
+        for data_file in data_files:
+            try:
+                with open(data_file, "r", encoding="utf-8") as fp:
+                    json.load(fp)
+            except Exception as e:
+                logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
+                return False
+
+        return True
+
+
+class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
+    def __init__(self, data_files: list[str]):
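+        """Load rows from one or more CSV files (header row required) into the in-memory table, dropping unknown columns."""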
+        super().__init__()
+
+        for data_file in data_files:
+            with open(data_file, "r", encoding="utf-8") as fp:
+                for i, parsed in enumerate(csv.DictReader(fp)):
+                    keys = set(parsed.keys())
+
+                    for k in keys - set(DB_FIELDS):
+                        del parsed[k]
+
+                    if (missing_keys := self._check_keys(keys)):
+                        raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
+
+                    self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
+
+        self._builds_init()
+
+    @staticmethod
+    def valid_format(data_files: list[str]) -> bool:
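+        """Best-effort probe: return True if every file in data_files opens and its first row reads as CSV."""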
+        if not data_files:
+            return False
+
+        for data_file in data_files:
+            try:
+                with open(data_file, "r", encoding="utf-8") as fp:
+                    for parsed in csv.DictReader(fp):
+                        break
+            except Exception as e:
+                logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
+                return False
+
+        return True
+
+
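+# Probe the input format: a single file may be SQLite, JSON, JSONL or CSV (tried in
+# that order); multiple files are supported for JSON and CSV only.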
+bench_data = None
+if len(input_file) == 1:
+    if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
+        bench_data = LlamaBenchDataSQLite3File(input_file[0])
+    elif LlamaBenchDataJSON.valid_format(input_file):
+        bench_data = LlamaBenchDataJSON(input_file)
+    elif LlamaBenchDataJSONL.valid_format(input_file[0]):
+        bench_data = LlamaBenchDataJSONL(input_file[0])
+    elif LlamaBenchDataCSV.valid_format(input_file):
+        bench_data = LlamaBenchDataCSV(input_file)
+else:
+    if LlamaBenchDataJSON.valid_format(input_file):
+        bench_data = LlamaBenchDataJSON(input_file)
+    elif LlamaBenchDataCSV.valid_format(input_file):
+        bench_data = LlamaBenchDataCSV(input_file)
+
+if not bench_data:
+    raise RuntimeError("No valid input files found, or at least one input file has an invalid format.")
+
+if not bench_data.builds:
+    raise RuntimeError(f"{input_file} does not contain any builds.")
 
 hexsha8_baseline = name_baseline = None
 
 # If the user specified a baseline, try to find a commit for it:
 if known_args.baseline is not None:
-    if known_args.baseline in builds:
+    if known_args.baseline in bench_data.builds:
         hexsha8_baseline = known_args.baseline
     if hexsha8_baseline is None:
-        hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
+        hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
         name_baseline = known_args.baseline
     if hexsha8_baseline is None:
         logger.error(f"cannot find data for baseline={known_args.baseline}.")
         sys.exit(1)
 # Otherwise, search for the most recent parent of master for which there is data:
-elif repo is not None:
-    hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
+elif bench_data.repo is not None:
+    hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)
 
     if hexsha8_baseline is None:
         logger.error("No baseline was provided and did not find data for any master branch commits.\n")
@@ -235,27 +464,25 @@ else:
     sys.exit(1)
 
-name_baseline = get_commit_name(hexsha8_baseline)
+name_baseline = bench_data.get_commit_name(hexsha8_baseline)
 
 hexsha8_compare = name_compare = None
 
 # If the user has specified a compare value, try to find a corresponding commit:
 if known_args.compare is not None:
-    if known_args.compare in builds:
+    if known_args.compare in bench_data.builds:
         hexsha8_compare = known_args.compare
     if hexsha8_compare is None:
-        hexsha8_compare = get_commit_hexsha8(known_args.compare)
+        hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
         name_compare = known_args.compare
     if hexsha8_compare is None:
         logger.error(f"cannot find data for compare={known_args.compare}.")
         sys.exit(1)
 # Otherwise, search for the commit for which llama-bench was most recently run
 # and that is not a parent of master:
-elif repo is not None:
-    hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit)
-    builds_timestamp = cursor.execute(
-        "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
-    for (hexsha8, _) in reversed(builds_timestamp):
+elif bench_data.repo is not None:
+    hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
+    for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
         if hexsha8 not in hexsha8s_master:
             hexsha8_compare = hexsha8
             break
@@ -270,26 +497,7 @@ else:
     parser.print_help()
     sys.exit(1)
 
-name_compare = get_commit_name(hexsha8_compare)
-
-
-def get_rows(properties):
-    """
-    Helper function that gets table rows for some list of properties.
-    Rows are created by combining those where all provided properties are equal.
-    The resulting rows are then grouped by the provided properties and the t/s values are averaged.
-    The returned rows are unique in terms of property combinations.
-    """
-    select_string = ", ".join(
-        [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
-    equal_string = " AND ".join(
-        [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
-            f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
-    )
-    group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
-    query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
-             f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
-    return cursor.execute(query).fetchall()
+name_compare = bench_data.get_commit_name(hexsha8_compare)
 
 # If the user provided columns to group the results by, use them:
@@ -303,10 +511,10 @@ if known_args.show is not None:
         logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
         parser.print_usage()
         sys.exit(1)
-    rows_show = get_rows(show)
+    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 # Otherwise, select those columns where the values are not all the same:
 else:
-    rows_full = get_rows(KEY_PROPERTIES)
+    rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare)
     properties_different = []
     for i, kp_i in enumerate(KEY_PROPERTIES):
         if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]:
@@ -336,7 +544,7 @@ else:
             show.remove(prop)
         except ValueError:
             pass
-    rows_show = get_rows(show)
+    rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 
 if not rows_show:
     logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")