"""Benchmark helpers for TauFactor convergence studies."""
from __future__ import annotations
import gc
import inspect
import io
import itertools
import os
import time
from collections.abc import Callable
from contextlib import redirect_stdout
import torch
import taufactor as tau
from taufactor.utils import create_fcc_cube, create_2d_diagonals, create_3d_diagonals, create_stacked_blocks, create_2d_zigzag
# File that benchmark rows are appended to when the caller gives no path.
DEFAULT_OUTFILE = "taufactor_benchmark_results.txt"


def _is_benchmarkable_solver(name: str, candidate: object) -> bool:
    """Return True for concrete classes whose name ends in ``Solver``."""
    if not name.endswith("Solver"):
        return False
    return isinstance(candidate, type) and not inspect.isabstract(candidate)


def _fcc_structure(N, features=None):
    """Build the ``fcc`` benchmark structure; ``features`` is accepted but unused."""
    # Cells equal to 0 in the generated cube become 1, all others become 0.
    return (create_fcc_cube(N, overlap=0.05) == 0).astype(int)


# Benchmarkable solvers exposed by the top-level package, keyed by class name.
SOLVER_REGISTRY = {
    name: candidate
    for name, candidate in vars(tau).items()
    if _is_benchmarkable_solver(name, candidate)
}

# Predefined structure generators, keyed by the name accepted by the benchmark API.
STRUCTURE_REGISTRY = {
    "fcc": _fcc_structure,
    "blocks": create_stacked_blocks,
    "diagonal2d": create_2d_diagonals,
    "zigzag": create_2d_zigzag,
    "diagonal3d": create_3d_diagonals,
}
[docs]
def resolve_solver(solver: str | type | None) -> type:
    """Turn a solver specification into a solver class.

    Args:
        solver: ``None`` (selects ``tau.PeriodicSolver``), a class object
            (returned unchanged), or a key of ``SOLVER_REGISTRY``.

    Returns:
        The solver class to instantiate.

    Raises:
        ValueError: If ``solver`` is a string that is not in ``SOLVER_REGISTRY``.
        TypeError: If ``solver`` is neither ``None``, a class, nor a string.
    """
    if solver is None:
        return tau.PeriodicSolver

    if isinstance(solver, type):
        return solver

    if not isinstance(solver, str):
        raise TypeError("solver must be None, a solver class, or a solver name string")

    if solver in SOLVER_REGISTRY:
        return SOLVER_REGISTRY[solver]

    known_names = ", ".join(sorted(SOLVER_REGISTRY))
    raise ValueError(
        f"Unknown solver '{solver}'. Available solvers: {known_names}"
    )
def _call_structure_hook(
structure_fn: Callable,
N: int,
features: int | None,
):
"""Call custom structure hooks with flexible, user-friendly signatures."""
attempts = [
lambda: structure_fn(N=N, features=features),
lambda: structure_fn(Nx=N, features=features),
lambda: structure_fn(N, features=features),
lambda: structure_fn(N),
]
last_exc = None
for attempt in attempts:
try:
return attempt()
except TypeError as exc:
last_exc = exc
continue
raise TypeError(
"Unable to call custom structure hook. Expected a callable that accepts "
"N or Nx (optionally features)."
) from last_exc
[docs]
def resolve_structure(
    structure: str | Callable,
    N: int,
    features: int | None = None,
) -> tuple:
    """Build a benchmark structure and report the name it should be logged under.

    Args:
        structure: A key of ``STRUCTURE_REGISTRY`` or a custom callable hook.
        N: Size argument forwarded to the generator.
        features: Optional value forwarded as the ``features`` keyword.

    Returns:
        ``(cube, name)`` where ``cube`` is the generator's return value and
        ``name`` is the registry key or, for a hook, its ``__name__`` (falling
        back to its class name).

    Raises:
        ValueError: If ``structure`` is a string missing from the registry.
        TypeError: If ``structure`` is neither a string nor callable.
    """
    if isinstance(structure, str):
        if structure in STRUCTURE_REGISTRY:
            builder = STRUCTURE_REGISTRY[structure]
            return builder(N, features=features), structure
        known_names = ", ".join(sorted(STRUCTURE_REGISTRY))
        raise ValueError(
            f"Unknown structure '{structure}'. Supported: {known_names}"
        )

    if not callable(structure):
        raise TypeError("structure must be a predefined structure name or a callable hook")

    cube = _call_structure_hook(structure, N=N, features=features)
    label = getattr(structure, "__name__", structure.__class__.__name__)
    return cube, label
[docs]
def append_row_to_file(row: dict, outfile: str = DEFAULT_OUTFILE) -> None:
    """Append one benchmark result to ``outfile`` as a fixed-width text line.

    Args:
        row: Result mapping with the keys ``N``, ``structure``, ``solver``,
            ``device``, ``conv_crit``, ``total_time``, ``solve_time``,
            ``iterations``, ``taufactor``, ``torch_cur``, ``torch_max`` and
            ``torch_res``.
        outfile: Path of the file to append to; opened in append mode as UTF-8.
    """
    # Text columns are truncated to their column width before right-alignment.
    structure_label = row["structure"][:10]
    solver_label = row["solver"][:16]
    device_label = row["device"][:4]

    columns = (
        f"{row['N']:4d}",
        f"{structure_label:>10}",
        f"{solver_label:>16}",
        f"{device_label:>4}",
        f"{row['conv_crit']:.4f}",
        f"{row['total_time']:9.3f}",
        f"{row['solve_time']:9.3f}",
        f"{row['iterations']:6d}",
        f"{row['taufactor']:8.3f}",
        f"{row['torch_cur']:10.2f}",
        f"{row['torch_max']:10.2f}",
        f"{row['torch_res']:10.2f}",
    )
    record = " ".join(columns) + "\n"

    with open(outfile, "a", encoding="utf-8") as handle:
        handle.write(record)
[docs]
def run_benchmark_case(
    N: int,
    device: str,
    conv_crit: float,
    structure: str | Callable = "fcc",
    features: int | None = None,
    iter_limit: int = 10000,
    solver: str | type | None = None,
    solver_kwargs: dict | None = None,
    solve_kwargs: dict | None = None,
) -> dict:
    """Run a single structure benchmark case.

    Args:
        N: Size argument forwarded to the structure generator.
        device: Device string passed to the solver. Any string starting with
            ``cuda`` (e.g. ``"cuda"`` or ``"cuda:0"``) triggers CUDA cache
            clearing and synchronization around the timed region.
        conv_crit: Convergence criterion passed to ``solve``.
        structure: Either one of the predefined keys in ``STRUCTURE_REGISTRY``
            or a custom callable hook returning a structure array.
        features: Optional value forwarded to the structure generator.
        iter_limit: Iteration limit passed to ``solve``.
        solver: Solver class, solver name, or ``None`` for the default solver.
        solver_kwargs: Extra keyword arguments for the solver constructor.
        solve_kwargs: Extra keyword arguments for ``solve``.

    Returns:
        A result row with the keys ``N``, ``structure``, ``solver``,
        ``device``, ``conv_crit``, ``total_time`` (measured here around solver
        construction and solve), ``solve_time`` (the solver's ``walltime``),
        ``iterations`` (the solver's ``iter``), ``taufactor`` (first entry of
        the solver's ``tau``) and ``torch_cur`` / ``torch_max`` /
        ``torch_res`` (parsed from the solver's ``GPU-RAM`` output line, or
        ``0.0`` when unavailable).
    """
    cube, structure_name = resolve_structure(structure, N=N, features=features)
    solver_cls = resolve_solver(solver)
    # Copy so the caller's dictionaries are never shared with the solver.
    solver_kwargs = dict(solver_kwargs or {})
    solve_kwargs = dict(solve_kwargs or {})

    # Indexed devices such as "cuda:1" need the same handling as plain "cuda";
    # without synchronization the timings would not cover queued GPU work.
    cuda_device = torch.device(device) if str(device).startswith("cuda") else None

    if cuda_device is not None:
        torch.cuda.empty_cache()
    torch._dynamo.reset()
    gc.collect()
    if cuda_device is not None:
        torch.cuda.synchronize(cuda_device)
    start_time = time.perf_counter()

    # The solver's console output is captured so it can be parsed afterwards.
    captured = io.StringIO()
    with redirect_stdout(captured):
        solver_instance = solver_cls(cube, device=device, **solver_kwargs)
        solver_instance.solve(iter_limit=iter_limit, conv_crit=conv_crit, **solve_kwargs)
    if cuda_device is not None:
        torch.cuda.synchronize(cuda_device)
    end_time = time.perf_counter()

    torch_cur, torch_max, torch_res = _parse_gpu_ram_usage(
        captured.getvalue().splitlines()
    )

    return {
        "N": N,
        "structure": structure_name,
        "solver": solver_cls.__name__,
        "device": device,
        "conv_crit": conv_crit,
        "total_time": end_time - start_time,
        "solve_time": float(solver_instance.walltime),
        "iterations": int(solver_instance.iter),
        "taufactor": float(solver_instance.tau[0]),
        "torch_cur": torch_cur,
        "torch_max": torch_max,
        "torch_res": torch_res,
    }


def _parse_gpu_ram_usage(output_lines: list[str]) -> tuple[float, float, float]:
    """Extract the three ``GPU-RAM`` figures from captured solver output.

    The first line containing ``GPU-RAM`` is stripped of parentheses and
    commas, split on whitespace, and tokens 2, 6 and 8 are read as floats.
    Returns ``(0.0, 0.0, 0.0)`` when no such line exists or when it does not
    have the expected layout, so a reporting change cannot discard an
    otherwise finished benchmark run.
    """
    report = next((line for line in output_lines if "GPU-RAM" in line), "")
    if not report:
        return 0.0, 0.0, 0.0
    tokens = report.replace("(", "").replace(")", "").replace(",", "").split()
    try:
        return float(tokens[2]), float(tokens[6]), float(tokens[8])
    except (IndexError, ValueError):
        return 0.0, 0.0, 0.0
[docs]
def run_benchmark_study(
    Ns: list[int] | tuple[int, ...] = (100, 128, 200, 256, 300, 384, 400),
    devices: list[str] | tuple[str, ...] = ("cuda",),
    conv_crit_values: list[float] | tuple[float, ...] = (1e-3,),
    structure: str | Callable = "fcc",
    features: int = 1,
    outfile: str = DEFAULT_OUTFILE,
    write_file: bool = True,
    iter_limit: int = 10000,
    solver: str | type | None = None,
    solver_kwargs: dict | None = None,
    solve_kwargs: dict | None = None,
) -> list[dict]:
    """Run a convergence benchmark on synthetic structures.

    One case is run for every combination of ``Ns``, ``devices`` and
    ``conv_crit_values`` (``Ns`` varying slowest).

    Args:
        Ns: Structure sizes to benchmark.
        devices: Device strings to benchmark on. Entries starting with
            ``cuda`` (e.g. ``"cuda"``, ``"cuda:0"``) are skipped with a
            message when CUDA is not available.
        conv_crit_values: Convergence criteria to benchmark.
        structure: Either a predefined structure name (``fcc``, ``blocks``,
            ``diagonal2d``, ``zigzag``, ``diagonal3d``) or a custom callable
            hook. Custom hooks should accept ``N`` or ``Nx`` and can
            optionally accept ``features``.
        features: Value forwarded to the structure generator.
        outfile: File that result rows are appended to.
        write_file: When true, a header is ensured via
            ``write_header_if_missing`` and every row is appended to
            ``outfile`` as soon as it is available.
        iter_limit: Iteration limit passed to each case.
        solver: Solver class, solver name, or ``None`` for the default solver.
        solver_kwargs: Extra keyword arguments for the solver constructor.
        solve_kwargs: Extra keyword arguments for ``solve``.

    Returns:
        The result rows of all cases that were actually run, in execution order.
    """
    rows: list[dict] = []
    if write_file:
        write_header_if_missing(outfile=outfile)

    for N, device, conv_crit in itertools.product(Ns, devices, conv_crit_values):
        # Match indexed devices ("cuda:0") as well, so they are skipped rather
        # than failing inside the solver on machines without CUDA.
        wants_cuda = str(device).startswith("cuda")
        if wants_cuda and not torch.cuda.is_available():
            print(f"Skipping N={N} on CUDA (not available)")
            continue

        row = run_benchmark_case(
            N=N,
            device=device,
            conv_crit=conv_crit,
            structure=structure,
            features=features,
            iter_limit=iter_limit,
            solver=solver,
            solver_kwargs=solver_kwargs,
            solve_kwargs=solve_kwargs,
        )
        rows.append(row)
        # Written per case so partial results survive a later failure.
        if write_file:
            append_row_to_file(row, outfile=outfile)

    return rows