Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuner] Add device validation for user devices #60

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 64 additions & 12 deletions tuning/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Type, Optional, Callable, Iterable, Any
import pickle
from itertools import groupby
import iree.runtime as ireert
import random

"""
Expand All @@ -41,7 +42,7 @@

# Default values for num_candidates and devices, change it as needed
DEFAULT_NUM_CANDIDATES = 2048
DEFAULT_DEVICE_LIST = [0]
DEFAULT_DEVICE_LIST = ["hip://0"]

# Default values for max number of workers
DEFAULT_MAX_CPU_WORKERS = (
Expand Down Expand Up @@ -74,6 +75,7 @@ class CandidateTracker:
calibrated_benchmark_diff: Optional[float] = None



@dataclass(frozen=True)
class PathConfig:
# Preset constants
Expand Down Expand Up @@ -288,18 +290,64 @@ def generate_sample_result(
return f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}\nBM_run_forward/process_time/real_time_median\t {t1:.3g} ms\t {(t1+1):.3g} ms\t 5 items_per_second={t1/200:5f}/s\n\n"


def parse_devices(devices_str: str) -> list[int]:
"""Parse a comma-separated list of device IDs (e.g., "1,3,5" -> [1, 3, 5])."""
devices = []
try:
devices = [int(device.strip()) for device in devices_str.split(",")]
except ValueError as e:
def extract_driver_names(user_devices: list[str]) -> set[str]:
    """Return the distinct driver names used by the given device URIs.

    Each device is expected to look like ``<driver>://<path>`` (for example
    ``hip://0``); the driver name is everything before the first ``://``.
    An entry that contains no ``://`` is returned whole, so callers should
    check the format beforehand.

    Args:
        user_devices: Device URIs as supplied by the user.

    Returns:
        The set of driver names, e.g. ``{"hip", "local-sync"}``.
    """
    # partition() stops at the first separator and yields the whole string
    # when the separator is absent, exactly like split("://")[0].
    return {device.partition("://")[0] for device in user_devices}


def fetch_available_devices(drivers: Iterable[str]) -> list[str]:
    """Return every device the IREE runtime reports for the given drivers.

    Only the drivers passed in are queried. Each reported device is rendered
    as ``<driver>://<path>`` using the ``path`` entry of the device record.

    A ``ValueError`` raised while creating or querying a driver is reported
    through ``handle_error`` with ``exit_program=True``. If that call
    returns, the driver is skipped and the remaining drivers are still
    queried.

    Args:
        drivers: Driver names to query (e.g. ``"hip"``). Any iterable is
            accepted; the result follows its iteration order.

    Returns:
        Device URIs for all devices of all drivers that could be queried.
    """
    all_device_ids: list[str] = []

    for driver_name in drivers:
        # Keep the try body down to the two runtime calls that can fail.
        try:
            driver = ireert.get_driver(driver_name)
            devices = driver.query_available_devices()
        except ValueError as e:
            handle_error(
                condition=True,
                msg=f"Could not initialize driver {driver_name}: {e}",
                error_type=ValueError,
                exit_program=True,
            )
        else:
            all_device_ids.extend(
                f"{driver_name}://{device['path']}" for device in devices
            )

    return all_device_ids

def parse_devices(devices_str: str) -> list[str]:
    """Parse a comma-separated list of device URIs.

    Example:
        ``--devices=hip://0,local-sync://default`` ->
        ``["hip://0", "local-sync://default"]``

    Whitespace around each entry is stripped. Every entry must have the form
    ``<driver>://<path>`` with a non-empty driver; the path may be empty.
    Each offending entry is reported through ``handle_error`` with
    ``argparse.ArgumentTypeError``. If that call returns, parsing continues
    and the entry is still included in the result.

    Args:
        devices_str: Raw value of the ``--devices`` command-line option.

    Returns:
        The stripped entries in their original order.
    """
    devices = [device.strip() for device in devices_str.split(",")]
    for device in devices:
        driver, separator, _ = device.partition("://")
        # An empty entry (e.g. from "a,,b") has no separator either, so this
        # single check covers it along with a missing driver such as "://0".
        if not separator or not driver:
            handle_error(
                condition=True,
                msg=f"Invalid device list: {devices_str}. Error: {ValueError()}",
                error_type=argparse.ArgumentTypeError,
            )
    return devices


def validate_devices(user_devices: list[str]) -> None:
"""
Validate the user provided devices against the devices reported by the IREE Runtime.

Only the drivers named in user_devices are queried for their devices.
handle_error is called once per user device, with its condition set when
that device is not among the available devices.
"""
user_drivers = extract_driver_names(user_devices)

mihaescuvlad marked this conversation as resolved.
Show resolved Hide resolved
available_devices = fetch_available_devices(user_drivers)

for device in user_devices:
handle_error(
condition=True,
msg=f"Invalid device list: {devices_str}. Error: {e}",
error_type=argparse.ArgumentTypeError,
condition=(device not in available_devices),
msg=f"Invalid device specified: {device}",
# NOTE(review): argparse.ArgumentError's constructor takes (argument, message);
# confirm handle_error can build it from a single message string.
error_type=argparse.ArgumentError,
exit_program=True,
)
return devices


class ExecutionPhases(str, Enum):
Expand Down Expand Up @@ -330,7 +378,7 @@ def parse_arguments() -> argparse.Namespace:
"--devices",
type=parse_devices,
default=DEFAULT_DEVICE_LIST,
help="Comma-separated list of device IDs (e.g., --devices=0,1). Default: [0]",
help="Comma-separated list of device IDs (e.g., --devices=hip://,hip://GPU-UUID).",
)
parser.add_argument(
"--max-cpu-workers",
Expand Down Expand Up @@ -1240,6 +1288,10 @@ def autotune(args: argparse.Namespace) -> None:
setup_logging(args, path_config)
print(path_config.log_file_path, end="\n\n")

print("Validating devices")
validate_devices(args.devices)
print("Validation successful!")

print("Generating candidates...")
candidates = generate_candidates(args, path_config, candidate_trackers)
print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n")
Expand Down
2 changes: 1 addition & 1 deletion tuning/benchmark_dispatch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readonly NAME="$(basename "$INPUT" .mlir)"

# printf "Benchmarking $(basename ${INPUT}) on ${DEVICE}\n"

timeout 16s ./tools/iree-benchmark-module --device="hip://${DEVICE}" --module="${INPUT}" \
timeout 16s ./tools/iree-benchmark-module --device="${DEVICE}" --module="${INPUT}" \
--hip_use_streams=true --hip_allow_inline_execution=true \
--batch_size=1000 --benchmark_repetitions=3 > "${DIR}/benchmark_log_${DEVICE}.out" 2>&1 || (mv "$INPUT" "${DIR}/benchmark_failed" && exit 0)

Expand Down
2 changes: 1 addition & 1 deletion tuning/benchmark_unet_candidate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ shift 2
echo "Benchmarking: ${INPUT} on device ${DEVICE}"

timeout 180s tools/iree-benchmark-module \
--device="hip://${DEVICE}" \
--device="${DEVICE}" \
--hip_use_streams=true \
--hip_allow_inline_execution=true \
--device_allocator=caching \
Expand Down
111 changes: 111 additions & 0 deletions tuning/test_autotune.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse
import pytest
from unittest.mock import call, patch, MagicMock
import autotune

"""
Expand Down Expand Up @@ -363,3 +365,112 @@ def set_tracker(
assert (
dump_list == expect_dump_list
), "fail to parse incomplete baseline and candidates"
def test_extract_driver_names():
    """Driver names are the part before ``://`` and are de-duplicated."""
    # Two hip devices make sure the result really is a set of drivers.
    user_devices = ["hip://0", "hip://1", "local-sync://default", "cuda://default"]
    expected_output = {"hip", "local-sync", "cuda"}

    assert autotune.extract_driver_names(user_devices) == expected_output

def test_fetch_available_devices_success():
    """Devices of every driver come back as ``<driver>://<path>`` in input order."""
    drivers = ["hip", "local-sync", "cuda"]
    mock_devices = {
        "hip": [{"path": "0"}],
        "local-sync": [{"path": "default"}],
        "cuda": [{"path": "default"}],
    }

    def get_mock_driver(name):
        # A fresh mock per driver, so one driver's devices cannot leak into
        # the next through shared mock state.
        mock_driver = MagicMock()
        mock_driver.query_available_devices.return_value = mock_devices[name]
        return mock_driver

    with patch("autotune.ireert.get_driver", side_effect=get_mock_driver):
        actual_output = autotune.fetch_available_devices(drivers)

    expected_output = ["hip://0", "local-sync://default", "cuda://default"]
    assert actual_output == expected_output

def test_fetch_available_devices_failure():
    """A driver that fails to query is reported once and the others still load."""
    drivers = ["hip", "local-sync", "cuda"]
    mock_devices = {
        "hip": [{"path": "0"}],
        "local-sync": ValueError("Failed to initialize"),
        "cuda": [{"path": "default"}],
    }

    def get_mock_driver(name):
        mock_driver = MagicMock()
        outcome = mock_devices[name]
        if isinstance(outcome, list):
            mock_driver.query_available_devices.return_value = outcome
        else:
            # A mock raises any exception instance assigned to side_effect.
            mock_driver.query_available_devices.side_effect = outcome
        return mock_driver

    # handle_error is patched out, so the failing driver must simply be skipped.
    with patch("autotune.ireert.get_driver", side_effect=get_mock_driver), \
            patch("autotune.handle_error") as mock_handle_error:
        actual_output = autotune.fetch_available_devices(drivers)

    expected_output = ["hip://0", "cuda://default"]
    assert actual_output == expected_output
    mock_handle_error.assert_called_once_with(
        condition=True,
        msg="Could not initialize driver local-sync: Failed to initialize",
        error_type=ValueError,
        exit_program=True,
    )

def test_parse_devices():
    """Well-formed entries are stripped, kept in order, and raise no error."""
    user_devices_str = "hip://0, local-sync://default, cuda://default"
    expected_output = ["hip://0", "local-sync://default", "cuda://default"]

    with patch("autotune.handle_error") as mock_handle_error:
        actual_output = autotune.parse_devices(user_devices_str)

    assert actual_output == expected_output
    mock_handle_error.assert_not_called()

def test_parse_devices_with_invalid_input():
    """An entry without ``://`` is reported exactly once through handle_error."""
    user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default"
    expected_output = ["hip://0", "local-sync://default", "invalid_device", "cuda://default"]

    # With handle_error patched out, parsing carries on and the bad entry is
    # still part of the returned list.
    with patch("autotune.handle_error") as mock_handle_error:
        actual_output = autotune.parse_devices(user_devices_str)

    assert actual_output == expected_output
    mock_handle_error.assert_called_once_with(
        condition=True,
        msg=f"Invalid device list: {user_devices_str}. Error: {ValueError()}",
        error_type=argparse.ArgumentTypeError,
    )

def test_validate_devices():
    """No error condition is raised when every user device is available."""
    user_devices = ["hip://0", "local-sync://default"]
    user_drivers = {"hip", "local-sync"}

    with patch("autotune.extract_driver_names", return_value=user_drivers), \
            patch("autotune.fetch_available_devices", return_value=["hip://0", "local-sync://default"]), \
            patch("autotune.handle_error") as mock_handle_error:
        autotune.validate_devices(user_devices)

    # Without the count check, all() over zero recorded calls would pass vacuously.
    assert mock_handle_error.call_count == len(user_devices)
    assert all(
        recorded.kwargs["condition"] is False
        for recorded in mock_handle_error.call_args_list
    )

def test_validate_devices_with_invalid_device():
    """A device missing from the available list triggers the error condition."""
    user_devices = ["hip://0", "local-sync://default", "cuda://default"]
    user_drivers = {"hip", "local-sync", "cuda"}

    with patch("autotune.extract_driver_names", return_value=user_drivers), \
            patch("autotune.fetch_available_devices", return_value=["hip://0", "local-sync://default"]), \
            patch("autotune.handle_error") as mock_handle_error:
        autotune.validate_devices(user_devices)

    expected_call = call(
        condition=True,
        msg="Invalid device specified: cuda://default",
        error_type=argparse.ArgumentError,
        exit_program=True,
    )
    assert expected_call in mock_handle_error.call_args_list