|
|
@@ -19,6 +19,7 @@ except ImportError as e:
|
|
|
print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
|
|
|
raise e
|
|
|
|
|
|
+
|
|
|
logger = logging.getLogger("compare-llama-bench")
|
|
|
|
|
|
# All llama-bench SQL fields
|
|
|
@@ -122,11 +123,15 @@ help_s = (
|
|
|
parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
|
|
|
parser.add_argument("-s", "--show", help=help_s)
|
|
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
|
|
+parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
|
|
|
+parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
|
|
|
+parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")
|
|
|
|
|
|
known_args, unknown_args = parser.parse_known_args()
|
|
|
|
|
|
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
|
|
|
|
|
|
+
|
|
|
if known_args.check:
|
|
|
# Check if all required Python libraries are installed. Would have failed earlier if not.
|
|
|
sys.exit(0)
|
|
|
@@ -499,7 +504,6 @@ else:
|
|
|
|
|
|
name_compare = bench_data.get_commit_name(hexsha8_compare)
|
|
|
|
|
|
-
|
|
|
# If the user provided columns to group the results by, use them:
|
|
|
if known_args.show is not None:
|
|
|
show = known_args.show.split(",")
|
|
|
@@ -544,6 +548,14 @@ else:
|
|
|
show.remove(prop)
|
|
|
except ValueError:
|
|
|
pass
|
|
|
+
|
|
|
+ # Add plot_x parameter to parameters to show if it's not already present:
|
|
|
+ if known_args.plot:
|
|
|
+ for k, v in PRETTY_NAMES.items():
|
|
|
+ if v == known_args.plot_x and k not in show:
|
|
|
+ show.append(k)
|
|
|
+ break
|
|
|
+
|
|
|
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
|
|
|
|
|
|
if not rows_show:
|
|
|
@@ -600,6 +612,161 @@ if "gpu_info" in show:
|
|
|
headers = [PRETTY_NAMES[p] for p in show]
|
|
|
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
|
|
|
|
|
|
+if known_args.plot:
|
|
|
+ def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False):
|
|
|
+ try:
|
|
|
+ import matplotlib.pyplot as plt
|
|
|
+ import matplotlib
|
|
|
+ matplotlib.use('Agg')
|
|
|
+ except ImportError as e:
|
|
|
+ logger.error("matplotlib is required for --plot.")
|
|
|
+ raise e
|
|
|
+
|
|
|
+ data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
|
|
|
+ plot_x_index = None
|
|
|
+ plot_x_label = plot_x_param
|
|
|
+
|
|
|
+ if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
|
|
|
+ pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
|
|
|
+ if pretty_name in data_headers:
|
|
|
+ plot_x_index = data_headers.index(pretty_name)
|
|
|
+ plot_x_label = pretty_name
|
|
|
+ elif plot_x_param in data_headers:
|
|
|
+ plot_x_index = data_headers.index(plot_x_param)
|
|
|
+ plot_x_label = plot_x_param
|
|
|
+ else:
|
|
|
+ logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
|
|
|
+ return
|
|
|
+
|
|
|
+ grouped_data = {}
|
|
|
+
|
|
|
+ for i, row in enumerate(table_data):
|
|
|
+ group_key_parts = []
|
|
|
+ test_name = row[-4]
|
|
|
+
|
|
|
+ base_test = ""
|
|
|
+ x_value = None
|
|
|
+
|
|
|
+ if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
|
|
|
+ for j, val in enumerate(row[:-4]):
|
|
|
+ header_name = data_headers[j]
|
|
|
+ if val is not None and str(val).strip():
|
|
|
+ group_key_parts.append(f"{header_name}={val}")
|
|
|
+
|
|
|
+ if plot_x_param == "n_prompt" and "pp" in test_name:
|
|
|
+ base_test = test_name.split("@")[0]
|
|
|
+ x_value = base_test
|
|
|
+ elif plot_x_param == "n_gen" and "tg" in test_name:
|
|
|
+ x_value = test_name.split("@")[0]
|
|
|
+ elif plot_x_param == "n_depth" and "@d" in test_name:
|
|
|
+ base_test = test_name.split("@d")[0]
|
|
|
+ x_value = int(test_name.split("@d")[1])
|
|
|
+ else:
|
|
|
+ base_test = test_name
|
|
|
+
|
|
|
+ if base_test.strip():
|
|
|
+ group_key_parts.append(f"Test={base_test}")
|
|
|
+ else:
|
|
|
+ for j, val in enumerate(row[:-4]):
|
|
|
+ if j != plot_x_index:
|
|
|
+ header_name = data_headers[j]
|
|
|
+ if val is not None and str(val).strip():
|
|
|
+ group_key_parts.append(f"{header_name}={val}")
|
|
|
+ else:
|
|
|
+ x_value = val
|
|
|
+
|
|
|
+ group_key_parts.append(f"Test={test_name}")
|
|
|
+
|
|
|
+ group_key = tuple(group_key_parts)
|
|
|
+
|
|
|
+ if group_key not in grouped_data:
|
|
|
+ grouped_data[group_key] = []
|
|
|
+
|
|
|
+ grouped_data[group_key].append({
|
|
|
+ 'x_value': x_value,
|
|
|
+ 'baseline': float(row[-3]),
|
|
|
+ 'compare': float(row[-2]),
|
|
|
+ 'speedup': float(row[-1])
|
|
|
+ })
|
|
|
+
|
|
|
+ if not grouped_data:
|
|
|
+ logger.error("No data available for plotting")
|
|
|
+ return
|
|
|
+
|
|
|
+ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
|
|
|
+ from math import ceil
|
|
|
+ cols = 1 if num_groups == 1 else min(max_cols, num_groups)
|
|
|
+ rows = ceil(num_groups / cols)
|
|
|
+
|
|
|
+ # Scale figure size by grid dimensions
|
|
|
+ w, h = base_size
|
|
|
+ fig, ax_arr = plt.subplots(rows, cols,
|
|
|
+ figsize=(w * cols, h * rows),
|
|
|
+ squeeze=False)
|
|
|
+
|
|
|
+ axes = ax_arr.flatten()[:num_groups]
|
|
|
+ return fig, axes
|
|
|
+
|
|
|
+ num_groups = len(grouped_data)
|
|
|
+ fig, axes = make_axes(num_groups)
|
|
|
+
|
|
|
+ plot_idx = 0
|
|
|
+
|
|
|
+ for group_key, points in grouped_data.items():
|
|
|
+ if plot_idx >= len(axes):
|
|
|
+ break
|
|
|
+ ax = axes[plot_idx]
|
|
|
+
|
|
|
+ try:
|
|
|
+ points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
|
|
|
+ x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
|
|
|
+ except ValueError:
|
|
|
+ points_sorted = sorted(points, key=lambda p: group_key)
|
|
|
+ x_values = [p['x_value'] for p in points_sorted]
|
|
|
+
|
|
|
+ baseline_vals = [p['baseline'] for p in points_sorted]
|
|
|
+ compare_vals = [p['compare'] for p in points_sorted]
|
|
|
+
|
|
|
+ ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
|
|
|
+ label=f'{baseline_name}', linewidth=2, markersize=6)
|
|
|
+ ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
|
|
|
+ label=f'{compare_name}', linewidth=2, markersize=6)
|
|
|
+
|
|
|
+ if log_scale:
|
|
|
+ ax.set_xscale('log', base=2)
|
|
|
+ unique_x = sorted(set(x_values))
|
|
|
+ ax.set_xticks(unique_x)
|
|
|
+ ax.set_xticklabels([str(int(x)) for x in unique_x])
|
|
|
+
|
|
|
+ title_parts = []
|
|
|
+ for part in group_key:
|
|
|
+ if '=' in part:
|
|
|
+ key, value = part.split('=', 1)
|
|
|
+ title_parts.append(f"{key}: {value}")
|
|
|
+
|
|
|
+ title = ', '.join(title_parts) if title_parts else "Performance comparison"
|
|
|
+
|
|
|
+ ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
|
|
|
+ ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
|
|
|
+ ax.set_title(title, fontsize=12, fontweight='bold')
|
|
|
+ ax.legend(loc='best', fontsize=10)
|
|
|
+ ax.grid(True, alpha=0.3)
|
|
|
+
|
|
|
+ plot_idx += 1
|
|
|
+
|
|
|
+ for i in range(plot_idx, len(axes)):
|
|
|
+ axes[i].set_visible(False)
|
|
|
+
|
|
|
+ fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
|
|
|
+ fontsize=14, fontweight='bold')
|
|
|
+ fig.subplots_adjust(top=1)
|
|
|
+
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+ create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale)
|
|
|
+
|
|
|
print(tabulate( # noqa: NP100
|
|
|
table,
|
|
|
headers=headers,
|