diff --git a/tuner/__init__.py b/tuner/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/tuner/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/tuner/candidate_gen_test.py b/tuner/candidate_gen_test.py
index ee0a32c66..6c5af4fc4 100644
--- a/tuner/candidate_gen_test.py
+++ b/tuner/candidate_gen_test.py
@@ -9,7 +9,7 @@
 """
 
 import pytest
-import candidate_gen
+from . import candidate_gen
 
 
 def test_get_shaped_type_element_bitwidth():
diff --git a/tuner/examples/__init__.py b/tuner/examples/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/tuner/examples/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/tuner/examples/punet/.gitignore b/tuner/examples/punet/.gitignore
new file mode 100644
index 000000000..fae904ffb
--- /dev/null
+++ b/tuner/examples/punet/.gitignore
@@ -0,0 +1,3 @@
+# Test files/dirs recommended by README.md.
+dump-mmt
+test-benchmark.mlir
diff --git a/tuner/examples/punet/README.md b/tuner/examples/punet/README.md
new file mode 100644
index 000000000..f478c8695
--- /dev/null
+++ b/tuner/examples/punet/README.md
@@ -0,0 +1,46 @@
+# Punet Tuner
+
+## Environments
+Follow the instructions in [`/tuner/README.md`](../README.md).
+
+## Shell Scripts
+
+The required shell scripts can be downloaded from:
+[sdxl-scripts](https://github.com/nod-ai/sdxl-scripts).
+
+These scripts include:
+1. `compile-punet-base.sh` - Used for compiling model candidates.
+2. `compile_candidate.sh` - Used for compiling dispatch candidates.
+3. `punet.sh` - Invoked by `compile_candidate.sh`.
+
+Add the parent directories of these scripts to your `PATH` environment variable,
+so that they can be picked up by `punet_autotune.py`.
+
+## Running the Tuner
+
+### [Optional] Generate a tunable MLIR file
+Use
+[`punet.sh`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/punet.sh)
+to compile the sample matmul `mmt.mlir` (also available as
+[`mmt_unet.mlir`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/mmt_unet.mlir)):
+```shell
+punet.sh mmt.mlir -o mmt.vmfb --iree-hal-dump-executable-files-to=dump-mmt
+cp ./dump-mmt/module_main_0_dispatch_0_rocm_hsaco_fb_benchmark.mlir test-benchmark.mlir
+```
+
+### Recommended Trial Run
+For an initial trial to test the tuning loop, use:
+```shell
+python -m tuner.examples.punet.punet_autotune test-benchmark.mlir --num-candidates=10
+```
+
+### Dry Run Test
+To perform a dry run (no GPU required), use:
+```shell
+python -m tuner.examples.punet.punet_autotune test-benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
+```
+
+### Basic Usage
+```shell
+python -m tuner.examples.punet.punet_autotune test-benchmark.mlir
+```
diff --git a/tuner/examples/punet/__init__.py b/tuner/examples/punet/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/tuner/examples/punet/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/tuner/examples/punet/mmt.mlir b/tuner/examples/punet/mmt.mlir new file mode 100644 index 000000000..b9d6c5f4c --- /dev/null +++ b/tuner/examples/punet/mmt.mlir @@ -0,0 +1,11 @@ +!matA_0 = tensor<2048x1280xf16> +!matB_0 = tensor<10240x1280xf16> +!matC_0 = tensor<2048x10240xf32> + +func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 { + %cst = arith.constant 0.000000e+00 : f16 + %5 = tensor.empty() : !matC_0 + %6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0 + %8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0 + return %8 : !matC_0 +} diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py new file mode 100644 index 000000000..9fa20cb52 --- /dev/null +++ b/tuner/examples/punet/punet_autotune.py @@ -0,0 +1,191 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Sample Usage: + +python -m tuner.examples.punet.punet_autotune benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64 + + +Recommended Trial Run: + +python -m tuner.examples.punet.punet_autotune benchmark.mlir --num-candidates=1 + + +Dry Run Test (no gpu requried): + +python -m tuner.examples.punet.punet_autotune benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run + +""" + +from ... import libtuner +from pathlib import Path + + +class PunetClient(libtuner.TuningClient): + def get_dispatch_compile_timeout_s(self) -> int: + return 4 + + def get_dispatch_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + mlir_path = candidate_tracker.dispatch_mlir_path + assert mlir_path is not None + command = [ + "compile_candidate.sh", + mlir_path.as_posix(), + ] + return command + + def get_dispatch_benchmark_timeout_s(self) -> int: + return 15 + + def get_dispatch_benchmark_command( + self, + candidate_tracker: libtuner.CandidateTracker, + ) -> list[str]: + compiled_vmfb_path = candidate_tracker.compiled_dispatch_path + assert compiled_vmfb_path is not None + + command = [ + "iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + f"--module={compiled_vmfb_path.resolve()}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--batch_size=1000", + "--benchmark_repetitions=3", + f"--benchmark_out=dispatch_{candidate_tracker.candidate_id}_bm.json", + "--benchmark_out_format=json", + ] + + return command + + def get_model_compile_timeout_s(self) -> int: + return 300 + + def get_model_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + mlir_spec_path = candidate_tracker.spec_path + assert mlir_spec_path is not None + target_dir = mlir_spec_path.resolve().parent.parent.parent + output_name = f"unet_candidate_{candidate_tracker.candidate_id}.vmfb" + command = [ + "compile-punet-base.sh", + "iree-compile", + "gfx942", + f"{mlir_spec_path.resolve()}", + "./punet.mlir", + "-o", + (target_dir / output_name).as_posix(), + ] + return command + + def get_model_benchmark_timeout_s(self) -> int: + return 180 + + def get_model_benchmark_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + unet_candidate_path = candidate_tracker.compiled_model_path + assert unet_candidate_path is not None + + 
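+        # Benchmark the compiled punet model with iree-benchmark-module; libtuner
+        # substitutes the DEVICE_ID_PLACEHOLDER in --device with the worker's
+        # assigned device at runtime.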
command = [ + "iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--device_allocator=caching", + f"--module={unet_candidate_path.resolve()}", + "--parameters=model=punet.irpa", + "--function=main", + "--input=1x4x128x128xf16", + "--input=1xsi32", + "--input=2x64x2048xf16", + "--input=2x1280xf16", + "--input=2x6xf16", + "--input=1xf16", + "--benchmark_repetitions=5", + f"--benchmark_out=model_{candidate_tracker.candidate_id}_bm.json", + "--benchmark_out_format=json", + ] + return command + + +def main(): + args = libtuner.parse_arguments() + path_config = libtuner.PathConfig() + path_config.base_dir.mkdir(parents=True, exist_ok=True) + path_config.output_unilog.touch() + candidate_trackers: list[libtuner.CandidateTracker] = [] + punet_client = PunetClient() + stop_after_phase: str = args.stop_after + + print("Setup logging") + libtuner.setup_logging(args, path_config) + print(path_config.run_log, end="\n\n") + + if not args.dry_run: + print("Validating devices") + libtuner.validate_devices(args.devices) + print("Validation successful!\n") + + print("Generating candidates...") + candidates = libtuner.generate_candidates(args, path_config, candidate_trackers) + print(f"Stored candidates in {path_config.candidates_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: + return + + print("Compiling candidates...") + compiled_candidates = libtuner.compile_dispatches( + args, path_config, candidates, candidate_trackers, punet_client + ) + print(f"Compiled files are stored in {path_config.compiled_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches: + return + + print("Benchmarking compiled candidates...") + top_candidates = libtuner.benchmark_dispatches( + args, path_config, compiled_candidates, candidate_trackers, punet_client + ) + print(f"Stored results in {path_config.output_unilog}\n") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: + return + + print(f"Compiling top model candidates...") + punet_candidates = libtuner.compile_models( + args, path_config, top_candidates, candidate_trackers, punet_client + ) + print(f"Model candidates compiled in {path_config.base_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_models: + return + + print("Benchmarking model candidates...") + libtuner.benchmark_models( + args, path_config, punet_candidates, candidate_trackers, punet_client + ) + print(f"Stored results in {path_config.output_unilog}") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_models: + return + + libtuner.summerize_top_candidates(path_config, candidate_trackers) + print(f"Stored top candidates info in {path_config.result_summary_log}\n") + + libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers) + print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n") + + print("Check the detailed execution logs in:") + print(path_config.run_log) + + for candidate in candidate_trackers: + libtuner.logging.debug(candidate) + if args.verbose: + print(candidate) + + +if __name__ == "__main__": + main() diff --git a/tuner/libtuner.py b/tuner/libtuner.py new file mode 100644 index 000000000..396b535f1 --- /dev/null +++ b/tuner/libtuner.py @@ -0,0 +1,1362 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Provides fundamental functions for tuning: + - generate_candidates() + - compile_dispatches() + - benchmark_dispatches() + - compile_models() + - benchmark_models() + +Requires a wrapper Python script to import `libtuner`, +use the `TuningClient` API, customize compilation and benchmarking commands, +and implement a complete tuning loop for a specific model. +""" + + +import sys +import shutil +import subprocess +import logging +import argparse +from datetime import datetime +from enum import Enum +from pathlib import Path +import time +import multiprocessing +import queue +from tqdm import tqdm +import re +import hashlib +from dataclasses import dataclass, field +from typing import Type, Optional, Callable, Iterable, Any +import pickle +import random +from abc import ABC, abstractmethod +import iree.runtime as ireert +from . import candidate_gen + + +# Default values for num_candidates and devices, change it as needed +DEFAULT_NUM_CANDIDATES = 2048 +DEFAULT_DEVICE_LIST = ["hip://0"] + +# Default values for max number of workers +DEFAULT_MAX_CPU_WORKERS = ( + multiprocessing.cpu_count() // 2 +) # the actual amount of worker that will be generated = min(max_cpu_workers, len(task_list)) + +# Declare global variables at the module level for multiprocessing +worker_id = None +device_id = None + +# Declare special symbols for libtuner to search and locate +DEVICE_ID_PLACEHOLDER = "!DEVICE_ID!" + + +@dataclass +class CandidateTracker: + candidate_id: int + dispatch_mlir_path: Optional[Path] = None + dispatch_config_path: Optional[Path] = None + configuration: Optional[candidate_gen.Configuration] = None + compilation_successful: Optional[bool] = None + compiled_dispatch_path: Optional[Path] = None + compiled_dispatch_hash: Optional[str] = None + first_benchmark_time: Optional[float] = None + first_benchmark_device_id: Optional[str] = None + spec_path: Optional[Path] = None + compiled_model_path: Optional[Path] = None + compiled_model_hash: Optional[str] = None + model_benchmark_time: Optional[float] = None + model_benchmark_device_id: Optional[str] = None + baseline_benchmark_time: Optional[float] = None + calibrated_benchmark_diff: Optional[float] = None + + +@dataclass(frozen=True) +class PathConfig: + # Preset constants + global_config_prolog_mlir: Path = Path("./config_prolog.mlir") + global_config_epilog_mlir: Path = Path("./config_epilog.mlir") + model_baseline_vmfb: Path = Path("./baseline.vmfb") + + # Dynamic paths + base_dir: Path = field(init=False) + local_config_prolog_mlir: Path = field(init=False) + local_config_epilog_mlir: Path = field(init=False) + template_mlir: Path = field(init=False) + candidates_dir: Path = field(init=False) + candidate_configs_pkl: Path = field(init=False) + compiled_dir: Path = field(init=False) + compile_failed_dir: Path = field(init=False) + specs_dir: Path = field(init=False) + + output_unilog: Path = field(init=False) + result_summary_log: Path = field(init=False) + candidate_trackers_pkl: Path = field(init=False) + + # To be set outside of class + run_log: Optional[Path] = field(init=False, default=None) + + def __post_init__(self): + object.__setattr__(self, "base_dir", self._name_base_dir()) + object.__setattr__( + self, "local_config_prolog_mlir", self.base_dir / "config_prolog.mlir" + ) + object.__setattr__( + self, "local_config_epilog_mlir", self.base_dir / "config_epilog.mlir" + ) + object.__setattr__(self, "template_mlir", self.base_dir / "template.mlir") + 
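+        # All candidate artifacts (generated .mlir files, compiled .vmfb files,
+        # failures, and transform specs) are grouped under <base_dir>/candidates.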
object.__setattr__(self, "candidates_dir", self.base_dir / "candidates") + object.__setattr__( + self, "candidate_configs_pkl", self.candidates_dir / "configs.pkl" + ) + object.__setattr__(self, "compiled_dir", self.candidates_dir / "compiled") + object.__setattr__(self, "compile_failed_dir", self.candidates_dir / "failed") + object.__setattr__(self, "specs_dir", self.candidates_dir / "specs") + object.__setattr__(self, "output_unilog", self.base_dir / "output.log") + object.__setattr__( + self, "result_summary_log", self.base_dir / "result_summary.log" + ) + object.__setattr__( + self, "candidate_trackers_pkl", self.base_dir / "candidate_trackers.pkl" + ) + + def _name_base_dir(self) -> Path: + timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M") + base_dir = Path(f"./tuning_{timestamp}") + return base_dir + + def _set_run_log(self, run_log: Path): + object.__setattr__(self, "run_log", run_log) + + def get_candidate_mlir_path(self, candidate_id: int) -> Path: + return self.candidates_dir / f"{candidate_id}.mlir" + + def get_candidate_spec_mlir_path(self, candidate_id: int) -> Path: + return self.candidates_dir / "specs" / f"{candidate_id}_spec.mlir" + + def get_exe_format(self, path: Path) -> str: + return f"./{path.as_posix()}" + + def get_compiled_dispatch_index(self, file_path: Path) -> int: + return int(file_path.stem) + + def get_candidate_spec_filename(self, candidate_id: int) -> str: + return f"{candidate_id}_spec.mlir" + + def get_compiled_model_index(self, file_path: Path) -> int: + return int(file_path.stem.split("_")[-1]) + + +class TuningClient(ABC): + @abstractmethod + def get_dispatch_compile_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_dispatch_benchmark_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_model_compile_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_model_benchmark_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_dispatch_compile_timeout_s(self) -> int: + pass + + @abstractmethod + def get_dispatch_benchmark_timeout_s(self) -> int: + pass + + @abstractmethod + def get_model_compile_timeout_s(self) -> int: + pass + + @abstractmethod + def get_model_benchmark_timeout_s(self) -> int: + pass + + +@dataclass +class RunPack: + command: list[str] + check: bool = True + timeout_seconds: Optional[int] = None + + +@dataclass +class RunResult: + process_res: Optional[subprocess.CompletedProcess] + is_timeout: bool + + +@dataclass +class TaskPack: + run_pack: RunPack + candidate_id: int + command_need_device_id: bool = False + cooling_time: int = 0 + + +@dataclass +class TaskResult: + run_result: RunResult + candidate_id: int + device_id: str + + +@dataclass +class ParsedDisptachBenchmarkResult: + candidate_id: int + benchmark_time_in_seconds: float + candidate_mlir: Path + candidate_spec_mlir: Path + + +@dataclass +class IREEBenchmarkResult: + # Default format follows output of iree-benchmark-module + candidate_id: int + result_str: str + + def get_mean_time(self) -> Optional[float]: + if not self.result_str: + return None + pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}" + match = re.search(pattern, self.result_str) + if not match: + return None + try: + return float(match.group(1)) + except ValueError: + return None + + +def generate_display_DBR(candidate_id: int, mean_time: float) -> str: + """Generate 
dispatch_benchmark_result string for displaying""" + return f"{candidate_id}\tMean Time: {mean_time:.1f}" + + +def generate_display_MBR( + candidate_vmfb_path_str: str, + device_id: str, + t1: float, + calibrated_diff: Optional[float] = None, +) -> str: + """Generate model_benchmark_result string for displaying""" + if calibrated_diff: + percentage_change = calibrated_diff * 100 + change_str = f"({percentage_change:+.3f}%)" + res_str = f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g} {change_str}" + else: + res_str = ( + f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g}" + ) + return res_str + + +def extract_driver_names(user_devices: list[str]) -> set[str]: + """Extract driver names from the user devices""" + return {device.split("://")[0] for device in user_devices} + + +def fetch_available_devices(drivers: list[str]) -> list[str]: + """ + Extract all available devices on the user's machine for the provided drivers + Only the user provided drivers will be queried + """ + all_device_ids: list[str] = [] + + for driver_name in drivers: + try: + driver = ireert.get_driver(driver_name) + devices = driver.query_available_devices() + all_device_ids.extend( + f"{driver_name}://{device['path']}" for device in devices + ) + all_device_ids.extend( + f"{driver_name}://{device['device_id'] - 1}" for device in devices + ) + except ValueError as e: + handle_error( + condition=True, + msg=f"Could not initialize driver {driver_name}: {e}", + error_type=ValueError, + exit_program=True, + ) + + return all_device_ids + + +def parse_devices(devices_str: str) -> list[str]: + """ + Parse a comma-separated list of device IDs e.g.: + --devices=hip://0,local-sync://default -> ["hip://0", "local-sync://default"]). + """ + devices = [device.strip() for device in devices_str.split(",")] + for device in devices: + if "://" not in device or not device: + handle_error( + condition=True, + msg=f"Invalid device list: {devices_str}. 
Error: {ValueError()}", + error_type=argparse.ArgumentTypeError, + ) + return devices + + +def validate_devices(user_devices: list[str]) -> None: + """Validates the user provided devices against the devices extracted by the IREE Runtime""" + user_drivers = extract_driver_names(user_devices) + + available_devices = fetch_available_devices(list(user_drivers)) + + for device in user_devices: + handle_error( + condition=(device not in available_devices), + msg=f"Invalid device specified: {device}\nFetched available devices: {available_devices}", + error_type=argparse.ArgumentError, + exit_program=True, + ) + + +class ExecutionPhases(str, Enum): + dont_stop = "" + generate_candidates = "generate-candidates" + compile_dispatches = "compile-dispatches" + benchmark_dispatches = "benchmark-dispatches" + compile_models = "compile-models" + benchmark_models = "benchmark-models" + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Autotune script") + + # Required arguments + required_args = parser.add_argument_group("Required Options") + required_args.add_argument( + "input_file", type=Path, help="Path to the input benchmark file (.mlir)" + ) + + # General options + general_args = parser.add_argument_group("General Options") + general_args.add_argument( + "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" + ) + general_args.add_argument( + "--devices", + type=parse_devices, + default=DEFAULT_DEVICE_LIST, + help="Comma-separated list of device IDs (e.g., --devices=hip://,hip://GPU-UUID).", + ) + general_args.add_argument( + "--max-cpu-workers", + type=int, + default=DEFAULT_MAX_CPU_WORKERS, + help=f"Max number of workers for CPU-bounding tasks (default: {DEFAULT_MAX_CPU_WORKERS}, the number of CPUs in current system)", + ) + general_args.add_argument( + "--stop-after", + choices=[x.value for x in ExecutionPhases], + default=ExecutionPhases.dont_stop, + help="Stop execution after specified phase", + ) + general_args.add_argument( + "--num-model-candidates", + help="Maximum number of stage 2 candidates", + type=int, + default=50, + ) + general_args.add_argument( + "--dry-run", + action="store_true", + help="Do not attempt to run any modules or initialize the IREE runtime", + ) + + # candidate_gen.tune() options + candidate_gen_args = parser.add_argument_group("Candidate Generation Options") + candidate_gen_args.add_argument( + "--num-candidates", + type=int, + default=DEFAULT_NUM_CANDIDATES, + help=f"Number of candidates to be generated by candidate_gen.py (default: {DEFAULT_NUM_CANDIDATES})", + ) + candidate_gen_args.add_argument( + "--num-subgroups", + help="Number of subgroups per workgroup to use. 
(-1 == unconstrained)", + type=int, + default=-1, + ) + candidate_gen_args.add_argument( + "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" + ) + candidate_gen_args.add_argument( + "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" + ) + candidate_gen_args.add_argument( + "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" + ) + + return parser.parse_args() + + +def setup_logging(args: argparse.Namespace, path_config: PathConfig): + log_file_name = f"autotune_{args.input_file.stem}.log" + run_log_path = path_config.base_dir / log_file_name + path_config._set_run_log(run_log_path) + + # Create file handler for logging to a file + if path_config.run_log is None: + raise + file_handler = logging.FileHandler(path_config.run_log) + file_handler.setLevel(logging.DEBUG) + + # Create stream handler for logging to the console (only warnings and higher) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + + # Create a formatter that dynamically adds [levelname] for ERROR and WARNING + class CustomFormatter(logging.Formatter): + def format(self, record): + if record.levelno == logging.INFO: + return f"{record.message}" + else: + return f"[{record.levelname}] {record.message}" + + file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + console_formatter = CustomFormatter() + + # Set formatters to handlers + file_handler.setFormatter(file_formatter) + console_handler.setFormatter(console_formatter) + + # Configure the root logger + logging.basicConfig( + level=logging.DEBUG, # Set the root logger to the lowest level + handlers=[file_handler, console_handler], + ) + + # If verbose flag is set, add a console handler for INFO level and higher + if args.verbose: + verbose_console_handler = logging.StreamHandler() + verbose_console_handler.setLevel(logging.DEBUG) + verbose_console_handler.setFormatter(file_formatter) + logging.getLogger().addHandler(verbose_console_handler) + + # config logger in candidate_gen.py + tune_logger = logging.getLogger("tune") + tune_logger.setLevel(logging.DEBUG) + + # Log all arguments + logging.debug(f"Input Arguments:") + for arg, value in vars(args).items(): + tune_logger.info(f"{arg}: {value}") + + +def handle_error( + condition: bool, + msg: str, + level: int = logging.ERROR, + error_type: Type[BaseException] = Exception, + exit_program: bool = False, +) -> None: + """If meets the condition, handles errors with logging and optional program exit""" + if not condition: + return + + # Log the message with the specified level + if level == logging.CRITICAL: + logging.critical(msg) + raise error_type(msg) + if level == logging.ERROR: + logging.error(msg) + raise error_type(msg) + elif level == logging.WARNING: + logging.warning(msg) + elif level == logging.INFO: + logging.info(msg) + elif level == logging.DEBUG: + logging.debug(msg) + else: + raise ValueError( + "Invalid logging level specified: choose from logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG" + ) + + if exit_program: + sys.exit(1) + + +def init_worker_context(queue: multiprocessing.Queue) -> None: + """Assign a static index to current process as the worker ordinal, and specify the device indice to be used""" + global worker_id, device_id + + worker_id, device_id = queue.get() + + +def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int, int]]: + """Create queue contains Worker ID and Device ID for worker initialization""" + 
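+    # Each pool worker pops one (worker_id, device_id) pair from this managed
+    # queue in init_worker_context().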
    worker_contexts_queue = multiprocessing.Manager().Queue()
+    for worker_id, device_id in enumerate(device_ids):
+        worker_contexts_queue.put((worker_id, device_id))
+
+    return worker_contexts_queue
+
+
+def run_command(run_pack: RunPack) -> RunResult:
+    command = run_pack.command
+    check = run_pack.check
+    timeout_seconds = run_pack.timeout_seconds
+
+    result = None
+    is_timeout = False
+    try:
+        # Convert the command list to a command string for logging
+        command_str = " ".join(command)
+        logging.debug(f"Run: {command_str}")
+
+        # Add timeout to subprocess.run call
+        result = subprocess.run(
+            command,
+            check=check,
+            capture_output=True,
+            text=True,
+            timeout=timeout_seconds,
+        )
+
+        if result.stdout:
+            logging.debug(f"stdout: {result.stdout}")
+        if result.stderr:
+            logging.debug(f"stderr: {result.stderr}")
+    except subprocess.TimeoutExpired as e:
+        logging.warning(
+            f"Command '{command_str}' timed out after {timeout_seconds} seconds."
+        )
+        is_timeout = True
+    except subprocess.CalledProcessError as e:
+        print(e.output)
+        logging.error(
+            f"Command '{command_str}' returned non-zero exit status {e.returncode}."
+        )
+        logging.error(f"Command '{command_str}' failed with error: {e.stderr}")
+        if check:
+            raise
+    except KeyboardInterrupt:
+        print("Ctrl+C detected, terminating child processes...")
+
+    return RunResult(result, is_timeout)
+
+
+def run_command_wrapper(task_pack: TaskPack) -> TaskResult:
+    """Handle extra requirements and record more data for run_command()"""
+    if task_pack.command_need_device_id:
+        # Worker searches for the special symbol and substitutes it with the actual device_id
+        pattern = re.compile(re.escape(DEVICE_ID_PLACEHOLDER))
+        task_pack.run_pack.command = [
+            pattern.sub(str(device_id), s) for s in task_pack.run_pack.command
+        ]
+
+    run_result = run_command(task_pack.run_pack)
+
+    task_result = TaskResult(
+        run_result, task_pack.candidate_id, device_id=str(-1)
+    )  # Main process
+    if device_id:
+        task_result = TaskResult(
+            run_result, task_pack.candidate_id, device_id
+        )  # Subprocess
+
+    time.sleep(task_pack.cooling_time)
+
+    return task_result
+
+
+def multiprocess_progress_wrapper(
+    num_worker: int,
+    task_list: list,
+    function: Callable,
+    initializer: Optional[Callable] = None,
+    initializer_inputs: Optional[Iterable[Any]] = None,
+) -> list[Any]:
+    """Wrapper around a multiprocessing pool with a progress bar"""
+    results = []
+    initializer_inputs = initializer_inputs or ()
+
+    # Create a multiprocessing pool
+    with multiprocessing.Pool(
+        num_worker, initializer, initializer_inputs
+    ) as worker_pool:
+        # Use tqdm to create a progress bar
+        with tqdm(total=len(task_list)) as pbar:
+            try:
+                # Use imap_unordered to asynchronously execute the worker function on each task
+                for result in worker_pool.imap_unordered(function, task_list):
+                    pbar.update(1)  # Update progress bar
+                    results.append(result)
+            except KeyboardInterrupt:
+                # If Ctrl+C is pressed, terminate all child processes
+                worker_pool.terminate()
+                worker_pool.join()
+                sys.exit(1)  # Exit the script
+
+    return results
+
+
+def numerical_sort_key(path: Path) -> tuple[int | float, str]:
+    """
+    Define a sort key function that splits the filename into a numeric and a string part.
+ Order: 0 | 0_a | 0_b | 1 | 1_a | 2 + """ + numeric_part: int | float + # Extract the numeric part at the start of the filename + match = re.match(r"(\d+)", path.stem) + if match: + numeric_part = int(match.group(1)) + # The rest of the filename after the numeric part + remaining_part = path.stem[len(match.group(0)) :] + else: + numeric_part = float("inf") + remaining_part = path.stem + return (numeric_part, remaining_part) + + +def calculate_md5(file_path: Path) -> str: + md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + return md5.hexdigest() + + +def find_collisions( + hash_list: list[tuple[int, str]] +) -> tuple[bool, list[tuple[str, list[int]]]]: + """ + Detect hash value collisions + Take input list of candidate index numbers and hash value strings: ex. [(1, 'abc'), (2, 'def'), (3, 'abc')] + Return collision boolean value and list of unique hash values along with their corresponding indices: ex. [('abc', [1,3]), ('def', [2])] + """ + hash_count: dict[str, list[int]] = {} + + # Count occurrences of each hash_val + for index, hash_val in hash_list: + if hash_val in hash_count: + hash_count[hash_val].append(index) + else: + hash_count[hash_val] = [index] + + # Prepare output for all hash values + hash_values = [(hash_val, indices) for hash_val, indices in hash_count.items()] + + # Determine if there are collisions + collisions_exist = any(len(indices) > 1 for hash_val, indices in hash_count.items()) + + return collisions_exist, hash_values + + +def load_pickle(file_path: Path) -> list[Any]: + handle_error( + condition=(not file_path.exists()), + msg=f"Configuration file not found: {file_path}", + error_type=FileNotFoundError, + ) + with open(file_path, "rb") as file: + loaded_array = pickle.load(file) + return loaded_array + + +def save_pickle(file_path: Path, input_list: list[Any]) -> None: + with open(file_path, "wb") as file: + pickle.dump(input_list, file) + + +def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None: + """Appends new content to the end of the output.log.""" + title_str = "=" * 5 + f" {title} " + "=" * 5 + "\n" if title != "" else "" + with open(filepath, "a") as file: + file.write(title_str) + file.writelines(lines) + file.write("\n") + + +def generate_candidates( + args: argparse.Namespace, + path_config: PathConfig, + candidate_trackers: list[CandidateTracker], +) -> list[int]: + """Generate candidate files for tuning. 
Returns the list of candidate indexes""" + logging.debug("generate_candidates()") + + try: + shutil.copy( + path_config.global_config_epilog_mlir, path_config.local_config_epilog_mlir + ) + shutil.copy( + path_config.global_config_prolog_mlir, path_config.local_config_prolog_mlir + ) + except FileNotFoundError as e: + handle_error( + condition=True, + msg=f"Configuration file not found: {e}", + error_type=FileNotFoundError, + ) + + shutil.copy(args.input_file, path_config.template_mlir) + + mlirs = [] + try: + logging.debug("Captured messages from candidate_gen.py:") + candidate_gen.tune( + input=str(path_config.template_mlir), + output=str(path_config.candidates_dir), + limit=args.num_candidates, + num_subgroups=args.num_subgroups, + lhs_dims=args.lhs_dims, + rhs_dims=args.rhs_dims, + tile_dims=args.tile_dims, + ) + mlirs = sorted( + path_config.candidates_dir.glob("*.mlir"), key=numerical_sort_key + ) + except Exception as e: + logging.error("An error occurred during candidates generation: %s", str(e)) + # Capture and log debug messages from candidate_gen.py + tune_logger = logging.getLogger("tune") + for handler in logging.getLogger().handlers: + if isinstance(handler, logging.FileHandler): + tune_logger.handlers.append(handler) + tune_logger.exception("Error in candidate_gen.py:") + raise + logging.debug("candidate_gen.py ends") + + candidate_configs = load_pickle(path_config.candidate_configs_pkl) + candidate_configs.insert(0, None) # No Configuration class for 0.mlir + + # Create candidate trackers + assert len(mlirs) // 2 + 1 == len(candidate_configs) + candidates = [] + for mlir in mlirs: + if "_config.mlir" not in mlir.name: + candidates.append(int(mlir.stem)) + new_candidate = CandidateTracker( + candidate_id=int(mlir.stem), + dispatch_mlir_path=mlir, + configuration=candidate_configs[int(mlir.stem)], + ) + candidate_trackers.append(new_candidate) + else: + candidate_trackers[ + int(mlir.stem.split("_config")[0]) + ].dispatch_config_path = mlir + + handle_error( + condition=(len(candidates) == 0), msg="Failed to generate any candidates" + ) + + logging.info(f"Generated [{len(candidates)}] candidates") + + return candidates + + +def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, list[int]]: + """If a collision is found, generate a list of new indexes. 
If no collision, `unique_indexes = []`""" + # Check if candidate produces tbe same .vmfb + collision_detected, hash_list = find_collisions(index_hash_list) + unique_indexes: list[int] = [] + if not collision_detected: + return collision_detected, unique_indexes + + # If a collision is detected, select the first one from the collided list + logging.warning("Collisions detected") + for hash_val, indices in hash_list: + if len(indices) != 1: + logging.warning(f"Hash value '{hash_val}' collided at candidate {indices}.") + unique_indexes.append(indices[0]) + + return collision_detected, unique_indexes + + +def compile_dispatches( + args: argparse.Namespace, + path_config: PathConfig, + candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +) -> list[int]: + logging.debug("compile_dispatches()") + + if not candidates: + logging.warning("No candidates to compile.") + return [] + + path_config.compiled_dir.mkdir(parents=True, exist_ok=True) + path_config.compile_failed_dir.mkdir(parents=True, exist_ok=True) + path_config.specs_dir.mkdir(parents=True, exist_ok=True) + + task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_dispatch_compile_command( + candidate_trackers[i] + ), + check=False, + timeout_seconds=tuning_client.get_dispatch_compile_timeout_s(), + ), + candidate_id=i, + ) + for i in candidates + ] + num_worker = min(args.max_cpu_workers, len(task_list)) + multiprocess_progress_wrapper( + num_worker=num_worker, task_list=task_list, function=run_command_wrapper + ) + + # Note: failed/incompleted candidates can also be detected by checking if subprocess.res is None + compiled_files = sorted( + path_config.compiled_dir.glob("*.vmfb"), key=numerical_sort_key + ) + failed_files = sorted( + path_config.compile_failed_dir.glob("*.mlir"), key=numerical_sort_key + ) + + total, good, bad = len(task_list), len(compiled_files), len(failed_files) + compiling_rate = good / total * 100 + logging.info( + f"Total: {total} | Compiled: {good} | Failed: {bad} | Compiling Rate: {compiling_rate:.1f}%" + ) + + # Update candidate tracker + for failed_file in failed_files: + index = path_config.get_compiled_dispatch_index(failed_file) + candidate_trackers[index].compilation_successful = False + compiled_candidates = [] + compiled_candidates_hash_list = [] + for compiled_file in compiled_files: + index = path_config.get_compiled_dispatch_index(compiled_file) + compiled_candidates.append(index) + candidate_trackers[index].compilation_successful = True + candidate_trackers[index].compiled_dispatch_path = compiled_file + compiled_vmfb_path = candidate_trackers[index].compiled_dispatch_path + assert compiled_vmfb_path is not None + hash_val = calculate_md5(compiled_vmfb_path) + candidate_trackers[index].compiled_dispatch_hash = hash_val + compiled_candidates_hash_list.append((index, hash_val)) + + handle_error( + condition=(good == 0), msg="Failed to compile all candidate .mlir files" + ) + handle_error( + condition=(compiling_rate < 10), + msg=f"Compiling rate [{compiling_rate:.1f}%] < 10%", + level=logging.WARNING, + ) + + collision_detected, unique_indexes = collision_handler( + compiled_candidates_hash_list + ) + if collision_detected: + logging.info(f"Remains [{len(unique_indexes)}] unique candidate indexes") + + return compiled_candidates if not collision_detected else unique_indexes + + +def parse_dispatch_benchmark_results( + path_config: PathConfig, + benchmark_results: list[TaskResult], + candidate_trackers: list[CandidateTracker], +) -> 
tuple[list[ParsedDisptachBenchmarkResult], list[str]]:
+    benchmark_result_configs = []
+    dump_list = []
+    incomplete_list = []
+
+    for benchmark_result in benchmark_results:
+        candidate_id = benchmark_result.candidate_id
+        process_res = benchmark_result.run_result.process_res
+
+        if not process_res:
+            if benchmark_result.run_result.is_timeout:
+                incomplete_list.append(candidate_id)
+            continue
+
+        res_str = process_res.stdout
+        res = IREEBenchmarkResult(candidate_id, res_str)
+        benchmark_time = res.get_mean_time()
+        assert benchmark_time is not None
+        candidate_trackers[candidate_id].first_benchmark_time = benchmark_time
+        candidate_trackers[
+            candidate_id
+        ].spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(
+            candidate_id
+        )
+        mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path
+        spec_path = candidate_trackers[candidate_id].spec_path
+        assert mlir_path is not None and spec_path is not None
+        dump_list.append(generate_display_DBR(candidate_id, benchmark_time) + "\n")
+
+        benchmark_result_configs.append(
+            (
+                ParsedDisptachBenchmarkResult(
+                    candidate_id,
+                    benchmark_time,
+                    mlir_path,
+                    spec_path,
+                )
+            )
+        )
+
+    if incomplete_list:
+        dump_list += [f"Candidate {i} not completed" for i in incomplete_list]
+
+    return benchmark_result_configs, dump_list
+
+
+def generate_sample_task_result(
+    stdout: str, candidate_id: int, device_id: str
+) -> TaskResult:
+    res = subprocess.CompletedProcess(
+        args=[""],
+        stdout=stdout,
+        returncode=0,
+    )
+    return TaskResult(run_result=res, candidate_id=candidate_id, device_id=device_id)
+
+
+def generate_dryrun_dispatch_benchmark_results(
+    compiled_candidates: list[int],
+) -> list[TaskResult]:
+    logging.debug("generate_dryrun_dispatch_benchmark_results()")
+
+    task_results = [
+        generate_sample_task_result(
+            f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms",
+            i,
+            str(0),
+        )
+        for i in compiled_candidates
+    ]
+
+    return task_results
+
+
+def generate_dryrun_model_benchmark_results(
+    model_candidates: list[int],
+) -> tuple[list[TaskResult], list[TaskResult]]:
+    candidate_results = []
+    for i, j in enumerate(model_candidates):
+        stdout = f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms"
+        candidate_results.append(generate_sample_task_result(stdout, j, str(i % 3)))
+
+    baseline_results = [
+        generate_sample_task_result(
+            f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms",
+            0,
+            str(i),
+        )
+        for i in range(3)
+    ]
+
+    return candidate_results, baseline_results
+
+
+def benchmark_dispatches(
+    args: argparse.Namespace,
+    path_config: PathConfig,
+    compiled_candidates: list[int],
+    candidate_trackers: list[CandidateTracker],
+    tuning_client: TuningClient,
+):
+    logging.debug("benchmark_dispatches()")
+
+    if args.dry_run:
+        benchmark_results = generate_dryrun_dispatch_benchmark_results(
+            compiled_candidates
+        )
+    else:
+        # Benchmarking dispatch candidates
+        task_list = [
+            TaskPack(
+                RunPack(
+                    command=tuning_client.get_dispatch_benchmark_command(
+                        candidate_trackers[i]
+                    ),
+                    check=False,
+                    timeout_seconds=tuning_client.get_dispatch_benchmark_timeout_s(),
+                ),
+                candidate_id=i,
+                command_need_device_id=True,
+            )
+            for i in compiled_candidates
+        ]
+        worker_context_queue = create_worker_context_queue(args.devices)
+        benchmark_results = multiprocess_progress_wrapper(
+            num_worker=len(args.devices),
+            task_list=task_list,
+            function=run_command_wrapper,
+            initializer=init_worker_context,
+            initializer_inputs=(worker_context_queue,),
+        )
+
+    (
parsed_benchmark_results, + dispatch_benchmark_dump_list, + ) = parse_dispatch_benchmark_results( + path_config, benchmark_results, candidate_trackers + ) + append_to_file( + dispatch_benchmark_dump_list, + filepath=path_config.output_unilog, + title="All Dispatch Benchmark Results", + ) + + benchmarking_rate = (len(parsed_benchmark_results) / len(benchmark_results)) * 100 + logging.info( + f"Total: {len(benchmark_results)} | Benchmarked: {len(parsed_benchmark_results)} | Failed: {len(benchmark_results) - len(parsed_benchmark_results)} | Benchmarking Rate: {benchmarking_rate:.1f}%" + ) + handle_error( + condition=(len(benchmark_results) == 0), + msg="Failed to benchmark all candidate .vmfb files", + ) + + # Select top candidates + best_results = sorted( + parsed_benchmark_results, key=lambda x: float(x.benchmark_time_in_seconds) + )[: args.num_model_candidates] + logging.info(f"Selected top[{len(best_results)}]") + + dump_list = [ + f"{result.benchmark_time_in_seconds}\t{result.candidate_mlir.as_posix()}\t{result.candidate_spec_mlir.as_posix()}\n" + for result in best_results + ] + append_to_file( + dump_list, filepath=path_config.output_unilog, title="Top Candidates Results" + ) + + top_candidates = [result.candidate_id for result in best_results] + return top_candidates + + +def compile_models( + args: argparse.Namespace, + path_config: PathConfig, + candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +) -> list[int]: + logging.debug("compile_models()") + + candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb + + if args.dry_run: + for i in candidates: + candidate_trackers[i].compiled_model_path = Path(f"model_{i}.vmfb") + return candidates + + if not candidates: + logging.warning("No model candidates to compile.") + return [] + + task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_model_compile_command(candidate_trackers[i]), + check=False, + timeout_seconds=tuning_client.get_model_compile_timeout_s(), + ), + candidate_id=i, + ) + for i in candidates + if i != 0 + ] + num_worker = min(args.max_cpu_workers, len(task_list)) + multiprocess_progress_wrapper( + num_worker=num_worker, task_list=task_list, function=run_command_wrapper + ) + + model_candidates_files = list(path_config.base_dir.glob("*.vmfb")) + + model_candidates_indexes = [] + model_candidates_hash_list = [] + + # Update candidate tracker + for model_candidate in model_candidates_files: + assert model_candidate is not None + index = path_config.get_compiled_model_index(model_candidate) + candidate_trackers[index].compiled_model_path = model_candidate + hash_val = calculate_md5(model_candidate) + candidate_trackers[index].compiled_model_hash = hash_val + model_candidates_hash_list.append((index, hash_val)) + model_candidates_indexes.append(index) + + # Check if model candidate produces tbe same .vmfb + collision_detected, unique_model_candidates_indexes = collision_handler( + model_candidates_hash_list + ) + + if collision_detected: + logging.info( + f"Remains [{len(unique_model_candidates_indexes)}] unique candidate indexes" + ) + + return ( + unique_model_candidates_indexes + if collision_detected + else model_candidates_indexes + ) + + +def group_benchmark_results_by_device_id( + benchmark_results: list[TaskResult], +) -> list[list[TaskResult]]: + """ + Groups benchmark results by device ID. + + e.g. 
+ [TaskResult(res1, device_1), TaskResult(res2, device_2), TaskResult(res3, device_1)] + -----> + [ [TaskResult(res1, device_1), TaskResult(res3, device_1)], [TaskResult(res2, device_2)] ] + """ + grouped_results: dict[str, list[TaskResult]] = {} + for result in benchmark_results: + assert result.device_id is not None + if result.device_id not in grouped_results: + grouped_results[result.device_id] = [] + grouped_results[result.device_id].append(result) + + grouped_benchmark_results = [ + grouped_results[device_id] for device_id in sorted(grouped_results) + ] + + return grouped_benchmark_results + + +def parse_model_benchmark_results( + candidate_trackers: list[CandidateTracker], + candidate_results: list[TaskResult], + baseline_results: list[TaskResult], +): + """Update candidate_tracker and format a list of result strings to be saved later.""" + candidate_results = sorted(candidate_results, key=lambda br: br.device_id) + baseline_results = sorted(baseline_results, key=lambda tr: tr.device_id) + + # Assign candidates to the same groups by device_id + grouped_candidate_results = group_benchmark_results_by_device_id(candidate_results) + + # Insert baseline results to the head of each list + grouped_benchmark_results = [ + [x] + y for x, y in zip(baseline_results, grouped_candidate_results) + ] + + dump_list = [] + incomplete_list: list[ + tuple[int, Optional[str]] + ] = [] # format: [(candidate_id, device_id)] + + baseline_time = None + for same_device_results in grouped_benchmark_results: + dump_unsort_list: list[tuple[float, str]] = [] + for task_result in same_device_results: + candidate_id = task_result.candidate_id + device_id = task_result.device_id + process_res = task_result.run_result.process_res + + # Check if benchmarking has completed + if not process_res: + if task_result.run_result.is_timeout: + incomplete_list.append((candidate_id, device_id)) + if candidate_id == 0: + baseline_time = None + continue + + result_str = process_res.stdout + res = IREEBenchmarkResult(candidate_id, result_str) + benchmark_time = res.get_mean_time() + assert benchmark_time is not None + + # Record baseline benchmarking result and skip rest processes + if candidate_id == 0: + baseline_time = benchmark_time + baseline_vmfb_path = candidate_trackers[ + candidate_id + ].compiled_model_path + assert baseline_vmfb_path is not None + dump_str = ( + generate_display_MBR( + candidate_vmfb_path_str=baseline_vmfb_path.as_posix(), + device_id=device_id, + t1=benchmark_time, + ) + + "\n\n" + ) + dump_list.append(dump_str) + continue + + # Update candidate_tracker + candidate_trackers[candidate_id].model_benchmark_time = benchmark_time + candidate_trackers[candidate_id].model_benchmark_device_id = device_id + + # Calculate candidate improvement based on baseline. 
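+            # Relative change vs. the baseline on the same device:
+            # (candidate_time - baseline_time) / baseline_time.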
+ if baseline_time: + candidate_trackers[candidate_id].baseline_benchmark_time = baseline_time + calibrated_benchmark_diff = ( + benchmark_time - baseline_time + ) / baseline_time + candidate_trackers[ + candidate_id + ].calibrated_benchmark_diff = calibrated_benchmark_diff + else: + calibrated_benchmark_diff = None + + # Collect candidate dump str + candidate_vmfb_path = candidate_trackers[candidate_id].compiled_model_path + assert candidate_vmfb_path is not None + dump_str = ( + generate_display_MBR( + candidate_vmfb_path_str=candidate_vmfb_path.as_posix(), + device_id=device_id, + t1=benchmark_time, + calibrated_diff=calibrated_benchmark_diff, + ) + + "\n\n" + ) + + dump_unsort_list.append((benchmark_time, dump_str)) + + # Sort model candidate benchmarking result str in ascending time order. + dump_list = dump_list + [ + dump_str for _, dump_str in sorted(dump_unsort_list, key=lambda x: x[0]) + ] + + # Store incomplete .vmfb file at the end of dump_list. + for index, device in incomplete_list: + file_path = candidate_trackers[index].compiled_model_path + assert file_path is not None + error_msg = f"Benchmarking result of {file_path.as_posix()} on device {device} is incomplete" + handle_error(condition=True, msg=error_msg, level=logging.WARNING) + dump_list.append(error_msg + "\n") + + return dump_list + + +def benchmark_models( + args: argparse.Namespace, + path_config: PathConfig, + model_candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +): + """Benchmark U-Net candidate files and log the results.""" + logging.debug("benchmark_models()") + + if args.dry_run: + candidate_results, baseline_results = generate_dryrun_model_benchmark_results( + model_candidates + ) + else: + # Benchmarking model candidates + worker_context_queue = create_worker_context_queue(args.devices) + benchmark_task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_model_benchmark_command( + candidate_trackers[i] + ), + check=False, + timeout_seconds=tuning_client.get_dispatch_benchmark_timeout_s(), + ), + candidate_id=i, + command_need_device_id=True, + cooling_time=10, + ) + for i in model_candidates + ] + candidate_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=benchmark_task_list, + function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + + # Benchmarking baselines on each involved device + candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb + worker_context_queue = create_worker_context_queue(args.devices) + baseline_task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_model_benchmark_command( + candidate_trackers[0] + ), + check=False, + timeout_seconds=tuning_client.get_model_benchmark_timeout_s(), + ), + candidate_id=0, + command_need_device_id=True, + ) + ] * len(group_benchmark_results_by_device_id(candidate_results)) + baseline_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=baseline_task_list, + function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + + dump_list = parse_model_benchmark_results( + candidate_trackers, candidate_results, baseline_results + ) + + append_to_file( + dump_list, filepath=path_config.output_unilog, title="Model Benchmark Results" + ) + + +def summerize_top_candidates( + path_config: PathConfig, candidate_trackers: list[CandidateTracker] +): + dump_list = [] + top_candidates = [] + for candidate 
in candidate_trackers: + if candidate.candidate_id == 0 or candidate.model_benchmark_time is None: + continue + top_candidates.append( + (candidate.candidate_id, candidate.model_benchmark_time) + ) # collect (id, time) + + top_candidates = sorted( + top_candidates, key=lambda x: x[1] + ) # sort the list in ascending benchmark time order + top_candidate_ids = [item[0] for item in top_candidates] # get list of candidate id + + for candidate_id in top_candidate_ids: + candidate = candidate_trackers[candidate_id] + assert candidate.dispatch_config_path is not None + with open(candidate.dispatch_config_path, "r") as file: + config_file_contents = file.read() + final_str = f"Candidate {candidate.candidate_id}:\nModel benchmark time: {candidate.model_benchmark_time} on device {candidate.model_benchmark_device_id}\nDispatch benchmark time: {candidate.first_benchmark_time} on device {candidate.model_benchmark_device_id}\nSpec file path: {candidate.spec_path}\nSpec contents:{config_file_contents}\n\n" + dump_list.append(final_str) + + with open(path_config.result_summary_log, "w") as file: + file.writelines(dump_list) + + +def sanitize_filename(filename: str) -> str: + # Replace invalid characters by an underscore + sanitized = re.sub(r"[^\w\.-]", "_", filename) + return sanitized diff --git a/tuner/libtuner_test.py b/tuner/libtuner_test.py new file mode 100644 index 000000000..3cbaa5ed0 --- /dev/null +++ b/tuner/libtuner_test.py @@ -0,0 +1,406 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import pytest +from unittest.mock import call, patch, MagicMock +from . import libtuner + +""" +Usage: python -m pytest test_libtuner.py +""" + + +def test_group_benchmark_results_by_device_id(): + # Create mock TaskResult objects with device_id attributes + task_result_1 = MagicMock() + task_result_1.device_id = "device_1" + + task_result_2 = MagicMock() + task_result_2.device_id = "device_2" + + task_result_3 = MagicMock() + task_result_3.device_id = "device_1" + + benchmark_results = [task_result_1, task_result_2, task_result_3] + + expected_grouped_results = [ + [task_result_1, task_result_3], # Grouped by device_1 + [task_result_2], # Grouped by device_2 + ] + + grouped_results = libtuner.group_benchmark_results_by_device_id(benchmark_results) + + assert grouped_results == expected_grouped_results + assert grouped_results[0][0].device_id == "device_1" + assert grouped_results[1][0].device_id == "device_2" + + +def test_find_collisions(): + input = [(1, "abc"), (2, "def"), (3, "abc")] + assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])]) + input = [(1, "abc"), (2, "def"), (3, "hig")] + assert libtuner.find_collisions(input) == ( + False, + [("abc", [1]), ("def", [2]), ("hig", [3])], + ) + + +def test_collision_handler(): + input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")] + assert libtuner.collision_handler(input) == (True, [1, 2, 5]) + input = [(1, "abc"), (2, "def"), (3, "hig")] + assert libtuner.collision_handler(input) == (False, []) + + +def test_IREEBenchmarkResult_get(): + # Time is int + normal_str = r""" + ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + Benchmark Time CPU Iterations UserCounters... 
+ ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 271 us 275 us 3000 items_per_second=3.65611k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 274 us 275 us 3000 items_per_second=3.65481k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 273 us 275 us 3000 items_per_second=3.65671k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 274 us 275 us 3 items_per_second=3.65587k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 275 us 275 us 3 items_per_second=3.65611k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_stddev 0.073 us 0.179 us 3 items_per_second=0.971769/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_cv 0.03 % 0.07 % 3 items_per_second=0.03% + """ + res = libtuner.IREEBenchmarkResult(candidate_id=1, result_str=normal_str) + assert res.get_mean_time() == float(274) + + # Time is float + res = libtuner.IREEBenchmarkResult( + candidate_id=2, + result_str="process_time/real_time_mean 123.45 us, process_time/real_time_mean 246.78 us", + ) + assert res.get_mean_time() == 123.45 + + # Invalid str + res = libtuner.IREEBenchmarkResult(candidate_id=3, result_str="hello world") + assert res.get_mean_time() == None + res = libtuner.IREEBenchmarkResult(candidate_id=4, result_str="") + assert res.get_mean_time() == None + + +def test_generate_display_BR(): + output = libtuner.generate_display_DBR(1, 3.14) + expected = f"1\tMean Time: 3.1" + assert output == expected, "DispatchBenchmarkResult generates invalid sample string" + + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89) + expected = "Benchmarking: baseline.vmfb on device 1: 568" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, 0.0314) + expected = "Benchmarking: baseline.vmfb on device 1: 568 (+3.140%)" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, -3.14) + expected = "Benchmarking: baseline.vmfb on device 1: 568 (-314.000%)" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + + +def test_parse_dispatch_benchmark_results(): + base_path = libtuner.Path("/mock/base/dir") + spec_dir = base_path / "specs" + path_config = libtuner.PathConfig() + object.__setattr__(path_config, "specs_dir", spec_dir) + + mock_result_1 = MagicMock() + mock_result_1.run_result.process_res.stdout = "process_time/real_time_mean 100.0 us" + mock_result_1.candidate_id = 1 + mock_result_2 = MagicMock() + mock_result_2.run_result.process_res.stdout = "process_time/real_time_mean 200.0 us" + mock_result_2.candidate_id = 2 + mock_result_3 = MagicMock() + mock_result_3.run_result.process_res = None # Incomplete result + 
mock_result_3.candidate_id = 3 + benchmark_results = [mock_result_1, mock_result_2, mock_result_3] + + candidate_trackers = [] + for i in range(4): + tracker = libtuner.CandidateTracker(candidate_id=i) + tracker.dispatch_mlir_path = libtuner.Path(f"/mock/mlir/path/{i}.mlir") + candidate_trackers.append(tracker) + + expected_parsed_results = [ + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=1, + benchmark_time_in_seconds=100.0, + candidate_mlir=libtuner.Path("/mock/mlir/path/1.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/1_spec.mlir"), + ), + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=2, + benchmark_time_in_seconds=200.0, + candidate_mlir=libtuner.Path("/mock/mlir/path/2.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/2_spec.mlir"), + ), + ] + expected_dump_list = [ + "1\tMean Time: 100.0\n", + "2\tMean Time: 200.0\n", + "Candidate 3 not completed", + ] + + parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results( + path_config, benchmark_results, candidate_trackers + ) + + assert parsed_results == expected_parsed_results + assert dump_list == expected_dump_list + assert candidate_trackers[1].first_benchmark_time == 100.0 + assert candidate_trackers[1].spec_path == libtuner.Path( + "/mock/base/dir/specs/1_spec.mlir" + ) + assert candidate_trackers[2].first_benchmark_time == 200.0 + assert candidate_trackers[2].spec_path == libtuner.Path( + "/mock/base/dir/specs/2_spec.mlir" + ) + + +def test_parse_model_benchmark_results(): + # Setup mock data for candidate_trackers + tracker0 = libtuner.CandidateTracker(0) + tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb") + + tracker1 = libtuner.CandidateTracker(1) + tracker1.compiled_model_path = libtuner.Path("/path/to/model_1.vmfb") + + tracker2 = libtuner.CandidateTracker(2) + tracker2.compiled_model_path = libtuner.Path("/path/to/model_2.vmfb") + + tracker3 = libtuner.CandidateTracker(3) + tracker3.compiled_model_path = libtuner.Path("/path/to/model_3.vmfb") + + candidate_trackers = [tracker0, tracker1, tracker2, tracker3] + + # Setup mock data for task results + result1 = MagicMock() + result1.run_result.process_res.stdout = "1.23" + result1.candidate_id = 1 + result1.device_id = "device1" + + result2 = MagicMock() + result2.run_result.process_res.stdout = "4.56" + result2.candidate_id = 2 + result2.device_id = "device2" + + result3 = MagicMock() + result3.run_result.process_res.stdout = "0.98" + result3.candidate_id = 0 + result3.device_id = "device1" + + result4 = MagicMock() + result4.run_result.process_res.stdout = "4.13" + result4.candidate_id = 0 + result4.device_id = "device2" + + # Incomplete baseline on device3 + result5 = MagicMock() + result5.run_result.process_res = None + result5.candidate_id = 0 + result5.device_id = "device3" + + result6 = MagicMock() + result6.run_result.process_res.stdout = "3.38" + result6.candidate_id = 3 + result6.device_id = "device3" + + candidate_results = [result1, result2, result6] + baseline_results = [result3, result4, result5] + + # Skip real benchmark extraction, directly use given values from above + def mock_get_mean_time(self): + return float(self.result_str) if self.result_str else None + + # Mock IREEBenchmarkResult to return wanted benchmark times + with patch( + f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time", new=mock_get_mean_time + ): + # Mock handle_error to avoid actual logging during tests + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + dump_list = 
libtuner.parse_model_benchmark_results( + candidate_trackers, candidate_results, baseline_results + ) + + # Verify interactions with candidate_trackers + assert tracker1.model_benchmark_time == 1.23 + assert tracker1.model_benchmark_device_id == "device1" + assert tracker1.baseline_benchmark_time == 0.98 + assert tracker1.calibrated_benchmark_diff == pytest.approx( + (1.23 - 0.98) / 0.98, rel=1e-6 + ) + + assert tracker2.model_benchmark_time == 4.56 + assert tracker2.model_benchmark_device_id == "device2" + assert tracker2.baseline_benchmark_time == 4.13 + assert tracker2.calibrated_benchmark_diff == pytest.approx( + (4.56 - 4.13) / 4.13, rel=1e-6 + ) + + assert tracker3.model_benchmark_time == 3.38 + assert tracker3.model_benchmark_device_id == "device3" + + assert dump_list == [ + "Benchmarking: /path/to/baseline.vmfb on device device1: 0.98\n" "\n", + "Benchmarking: /path/to/model_1.vmfb on device device1: 1.23 (+25.510%)\n" + "\n", + "Benchmarking: /path/to/baseline.vmfb on device device2: 4.13\n" "\n", + "Benchmarking: /path/to/model_2.vmfb on device device2: 4.56 (+10.412%)\n" + "\n", + "Benchmarking: /path/to/model_3.vmfb on device device3: 3.38\n" "\n", + "Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete\n", + ] + + # Verify handle_error was called correctly + mock_handle_error.assert_called_once_with( + condition=True, + msg="Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete", + level=libtuner.logging.WARNING, + ) + + +def test_extract_driver_names(): + user_devices = ["hip://0", "local-sync://default", "cuda://default"] + expected_output = {"hip", "local-sync", "cuda"} + + assert libtuner.extract_driver_names(user_devices) == expected_output + + +def test_fetch_available_devices_success(): + drivers = ["hip", "local-sync", "cuda"] + mock_devices = { + "hip": [{"path": "ABCD", "device_id": 1}], + "local-sync": [{"path": "default", "device_id": 2}], + "cuda": [{"path": "default", "device_id": 3}], + } + + with patch(f"{libtuner.__name__}.ireert.get_driver") as mock_get_driver: + mock_driver = MagicMock() + + def get_mock_driver(name): + mock_driver.query_available_devices.side_effect = lambda: mock_devices[name] + return mock_driver + + mock_get_driver.side_effect = get_mock_driver + + actual_output = libtuner.fetch_available_devices(drivers) + expected_output = [ + "hip://ABCD", + "hip://0", + "local-sync://default", + "local-sync://1", + "cuda://default", + "cuda://2", + ] + + assert actual_output == expected_output + + +def test_fetch_available_devices_failure(): + drivers = ["hip", "local-sync", "cuda"] + mock_devices = { + "hip": [{"path": "ABCD", "device_id": 1}], + "local-sync": ValueError("Failed to initialize"), + "cuda": [{"path": "default", "device_id": 1}], + } + + with patch(f"{libtuner.__name__}.ireert.get_driver") as mock_get_driver: + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + mock_driver = MagicMock() + + def get_mock_driver(name): + if isinstance(mock_devices[name], list): + mock_driver.query_available_devices.side_effect = ( + lambda: mock_devices[name] + ) + else: + mock_driver.query_available_devices.side_effect = lambda: ( + _ for _ in () + ).throw(mock_devices[name]) + return mock_driver + + mock_get_driver.side_effect = get_mock_driver + + actual_output = libtuner.fetch_available_devices(drivers) + expected_output = ["hip://ABCD", "hip://0", "cuda://default", "cuda://0"] + + assert actual_output == expected_output + mock_handle_error.assert_called_once_with( + 
condition=True, + msg="Could not initialize driver local-sync: Failed to initialize", + error_type=ValueError, + exit_program=True, + ) + + +def test_parse_devices(): + user_devices_str = "hip://0, local-sync://default, cuda://default" + expected_output = ["hip://0", "local-sync://default", "cuda://default"] + + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + actual_output = libtuner.parse_devices(user_devices_str) + assert actual_output == expected_output + + mock_handle_error.assert_not_called() + + +def test_parse_devices_with_invalid_input(): + user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default" + expected_output = [ + "hip://0", + "local-sync://default", + "invalid_device", + "cuda://default", + ] + + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + actual_output = libtuner.parse_devices(user_devices_str) + assert actual_output == expected_output + + mock_handle_error.assert_called_once_with( + condition=True, + msg=f"Invalid device list: {user_devices_str}. Error: {ValueError()}", + error_type=argparse.ArgumentTypeError, + ) + + +def test_validate_devices(): + user_devices = ["hip://0", "local-sync://default"] + user_drivers = {"hip", "local-sync"} + + with patch(f"{libtuner.__name__}.extract_driver_names", return_value=user_drivers): + with patch( + f"{libtuner.__name__}.fetch_available_devices", + return_value=["hip://0", "local-sync://default"], + ): + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + libtuner.validate_devices(user_devices) + assert all( + call[1]["condition"] is False + for call in mock_handle_error.call_args_list + ) + + +def test_validate_devices_with_invalid_device(): + user_devices = ["hip://0", "local-sync://default", "cuda://default"] + user_drivers = {"hip", "local-sync", "cuda"} + + with patch(f"{libtuner.__name__}.extract_driver_names", return_value=user_drivers): + with patch( + f"{libtuner.__name__}.fetch_available_devices", + return_value=["hip://0", "local-sync://default"], + ): + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + libtuner.validate_devices(user_devices) + expected_call = call( + condition=True, + msg=f"Invalid device specified: cuda://default\nFetched available devices: ['hip://0', 'local-sync://default']", + error_type=argparse.ArgumentError, + exit_program=True, + ) + assert expected_call in mock_handle_error.call_args_list
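
Per the usage note in `libtuner_test.py`, the new tests are driven by pytest. A minimal invocation, given as a sketch under the assumption that the repository root is the working directory (so the `tuner` package and its relative `from . import libtuner` import resolve):

```shell
# Run the libtuner unit tests added in this change (no GPU required).
python -m pytest tuner/libtuner_test.py
```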