From 72c58a628aee0599b525d36ac51c1319455fb06c Mon Sep 17 00:00:00 2001 From: Mihaescu Vlad <52869843+mihaescuvlad@users.noreply.github.com> Date: Sat, 13 Jul 2024 17:59:09 +0300 Subject: [PATCH] Add device validation for user devices Initial changes did not fully commit Optimized validation & implemented CR suggestions Fix merge issues Add device validation for user devices Initial changes did not fully commit Optimized validation & implemented CR suggestions Add device validation for user devices Optimized validation & implemented CR suggestions Fix merge issues Add device validation for user devices Optimized validation & implemented CR suggestions Add unit tests for autotuning Update benchmark scripts to take in any device --- tuning/autotune.py | 76 ++++++++++++++++---- tuning/benchmark_dispatch.sh | 2 +- tuning/benchmark_unet_candidate.sh | 2 +- tuning/test_autotune.py | 111 +++++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 14 deletions(-) diff --git a/tuning/autotune.py b/tuning/autotune.py index e90728c..cd956a2 100755 --- a/tuning/autotune.py +++ b/tuning/autotune.py @@ -19,6 +19,7 @@ from typing import Type, Optional, Callable, Iterable, Any import pickle from itertools import groupby +import iree.runtime as ireert import random """ @@ -41,7 +42,7 @@ # Default values for num_candidates and devices, change it as needed DEFAULT_NUM_CANDIDATES = 2048 -DEFAULT_DEVICE_LIST = [0] +DEFAULT_DEVICE_LIST = ["hip://0"] # Default values for max number of workers DEFAULT_MAX_CPU_WORKERS = ( @@ -74,6 +75,7 @@ class CandidateTracker: calibrated_benchmark_diff: Optional[float] = None + @dataclass(frozen=True) class PathConfig: # Preset constants @@ -288,18 +290,64 @@ def generate_sample_result( return f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}\nBM_run_forward/process_time/real_time_median\t {t1:.3g} ms\t {(t1+1):.3g} ms\t 5 items_per_second={t1/200:5f}/s\n\n" -def parse_devices(devices_str: str) -> list[int]: - """Parse a comma-separated list of device IDs (e.g., "1,3,5" -> [1, 3, 5]).""" - devices = [] - try: - devices = [int(device.strip()) for device in devices_str.split(",")] - except ValueError as e: +def extract_driver_names(user_devices: list[str]) -> set[str]: + """Extract driver names from the user devices""" + return {device.split("://")[0] for device in user_devices} + + +def fetch_available_devices(drivers: list[str]) -> list[str]: + """ + Extract all available devices on the user's machine for the provided drivers + Only the user provided drivers will be queried + """ + all_device_ids = [] + + for driver_name in drivers: + try: + driver = ireert.get_driver(driver_name) + devices = driver.query_available_devices() + all_device_ids.extend( + f"{driver_name}://{device['path']}" for device in devices + ) + except ValueError as e: + handle_error( + condition=True, + msg=f"Could not initialize driver {driver_name}: {e}", + error_type=ValueError, + exit_program=True, + ) + + return all_device_ids + +def parse_devices(devices_str: str) -> list[str]: + """ + Parse a comma-separated list of device IDs e.g.: + --devices=hip://0,local-sync://default -> ["hip://0", "local-sync://default"]). + """ + devices = [device.strip() for device in devices_str.split(",")] + for device in devices: + if "://" not in device or not device: + handle_error( + condition=True, + msg=f"Invalid device list: {devices_str}. Error: {ValueError()}", + error_type=argparse.ArgumentTypeError, + ) + return devices + + +def validate_devices(user_devices: list[str]) -> None: + """Validates the user provided devices against the devices extracted by the IREE Runtime""" + user_drivers = extract_driver_names(user_devices) + + available_devices = fetch_available_devices(user_drivers) + + for device in user_devices: handle_error( - condition=True, - msg=f"Invalid device list: {devices_str}. Error: {e}", - error_type=argparse.ArgumentTypeError, + condition=(device not in available_devices), + msg=f"Invalid device specified: {device}", + error_type=argparse.ArgumentError, + exit_program=True, ) - return devices class ExecutionPhases(str, Enum): @@ -330,7 +378,7 @@ def parse_arguments() -> argparse.Namespace: "--devices", type=parse_devices, default=DEFAULT_DEVICE_LIST, - help="Comma-separated list of device IDs (e.g., --devices=0,1). Default: [0]", + help="Comma-separated list of device IDs (e.g., --devices=hip://,hip://GPU-UUID).", ) parser.add_argument( "--max-cpu-workers", @@ -1240,6 +1288,10 @@ def autotune(args: argparse.Namespace) -> None: setup_logging(args, path_config) print(path_config.log_file_path, end="\n\n") + print("Validating devices") + validate_devices(args.devices) + print("Validation successful!") + print("Generating candidates...") candidates = generate_candidates(args, path_config, candidate_trackers) print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n") diff --git a/tuning/benchmark_dispatch.sh b/tuning/benchmark_dispatch.sh index 2e4decb..dd8f29c 100755 --- a/tuning/benchmark_dispatch.sh +++ b/tuning/benchmark_dispatch.sh @@ -12,7 +12,7 @@ readonly NAME="$(basename "$INPUT" .mlir)" # printf "Benchmarking $(basename ${INPUT}) on ${DEVICE}\n" -timeout 16s ./tools/iree-benchmark-module --device="hip://${DEVICE}" --module="${INPUT}" \ +timeout 16s ./tools/iree-benchmark-module --device="${DEVICE}" --module="${INPUT}" \ --hip_use_streams=true --hip_allow_inline_execution=true \ --batch_size=1000 --benchmark_repetitions=3 > "${DIR}/benchmark_log_${DEVICE}.out" 2>&1 || (mv "$INPUT" "${DIR}/benchmark_failed" && exit 0) diff --git a/tuning/benchmark_unet_candidate.sh b/tuning/benchmark_unet_candidate.sh index 1fa3bf4..abc908c 100755 --- a/tuning/benchmark_unet_candidate.sh +++ b/tuning/benchmark_unet_candidate.sh @@ -9,7 +9,7 @@ shift 2 echo "Benchmarking: ${INPUT} on device ${DEVICE}" timeout 180s tools/iree-benchmark-module \ - --device="hip://${DEVICE}" \ + --device="${DEVICE}" \ --hip_use_streams=true \ --hip_allow_inline_execution=true \ --device_allocator=caching \ diff --git a/tuning/test_autotune.py b/tuning/test_autotune.py index 8b0c7ea..adabc3c 100644 --- a/tuning/test_autotune.py +++ b/tuning/test_autotune.py @@ -1,4 +1,6 @@ +import argparse import pytest +from unittest.mock import call, patch, MagicMock import autotune """ @@ -363,3 +365,112 @@ def set_tracker( assert ( dump_list == expect_dump_list ), "fail to parse incomplete baseline and candidates" +def test_extract_driver_names(): + user_devices = ["hip://0", "local-sync://default", "cuda://default"] + expected_output = {"hip", "local-sync", "cuda"} + + assert autotune.extract_driver_names(user_devices) == expected_output + +def test_fetch_available_devices_success(): + drivers = ["hip", "local-sync", "cuda"] + mock_devices = { + "hip": [{"path": "0"}], + "local-sync": [{"path": "default"}], + "cuda": [{"path": "default"}] + } + + with patch("autotune.ireert.get_driver") as mock_get_driver: + mock_driver = MagicMock() + + def get_mock_driver(name): + mock_driver.query_available_devices.side_effect = lambda: mock_devices[name] + return mock_driver + + mock_get_driver.side_effect = get_mock_driver + + actual_output = autotune.fetch_available_devices(drivers) + expected_output = ["hip://0", "local-sync://default", "cuda://default"] + + assert actual_output == expected_output + +def test_fetch_available_devices_failure(): + drivers = ["hip", "local-sync", "cuda"] + mock_devices = { + "hip": [{"path": "0"}], + "local-sync": ValueError("Failed to initialize"), + "cuda": [{"path": "default"}] + } + + with patch("autotune.ireert.get_driver") as mock_get_driver: + with patch("autotune.handle_error") as mock_handle_error: + mock_driver = MagicMock() + + def get_mock_driver(name): + if isinstance(mock_devices[name], list): + mock_driver.query_available_devices.side_effect = lambda: mock_devices[name] + else: + mock_driver.query_available_devices.side_effect = lambda: (_ for _ in ()).throw(mock_devices[name]) + return mock_driver + + mock_get_driver.side_effect = get_mock_driver + + actual_output = autotune.fetch_available_devices(drivers) + expected_output = ["hip://0", "cuda://default"] + + assert actual_output == expected_output + mock_handle_error.assert_called_once_with( + condition=True, + msg="Could not initialize driver local-sync: Failed to initialize", + error_type=ValueError, + exit_program=True, + ) + +def test_parse_devices(): + user_devices_str = "hip://0, local-sync://default, cuda://default" + expected_output = ["hip://0", "local-sync://default", "cuda://default"] + + with patch("autotune.handle_error") as mock_handle_error: + actual_output = autotune.parse_devices(user_devices_str) + assert actual_output == expected_output + + mock_handle_error.assert_not_called() + +def test_parse_devices_with_invalid_input(): + user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default" + expected_output = ["hip://0", "local-sync://default", "invalid_device", "cuda://default"] + + with patch("autotune.handle_error") as mock_handle_error: + actual_output = autotune.parse_devices(user_devices_str) + assert actual_output == expected_output + + mock_handle_error.assert_called_once_with( + condition=True, + msg=f"Invalid device list: {user_devices_str}. Error: {ValueError()}", + error_type=argparse.ArgumentTypeError, + ) + +def test_validate_devices(): + user_devices = ["hip://0", "local-sync://default"] + user_drivers = {"hip", "local-sync"} + + with patch('autotune.extract_driver_names', return_value=user_drivers): + with patch('autotune.fetch_available_devices', return_value=["hip://0", "local-sync://default"]): + with patch('autotune.handle_error') as mock_handle_error: + autotune.validate_devices(user_devices) + assert all(call[1]['condition'] is False for call in mock_handle_error.call_args_list) + +def test_validate_devices_with_invalid_device(): + user_devices = ["hip://0", "local-sync://default", "cuda://default"] + user_drivers = {"hip", "local-sync", "cuda"} + + with patch("autotune.extract_driver_names", return_value=user_drivers): + with patch("autotune.fetch_available_devices", return_value=["hip://0", "local-sync://default"]): + with patch("autotune.handle_error") as mock_handle_error: + autotune.validate_devices(user_devices) + expected_call = call( + condition=True, + msg=f"Invalid device specified: cuda://default", + error_type=argparse.ArgumentError, + exit_program=True, + ) + assert expected_call in mock_handle_error.call_args_list