Skip to content

Commit

Permalink
[tuner] Fix new lint err (#93)
Browse files Browse the repository at this point in the history
Fix new lint errors for #60
  • Loading branch information
RattataKing authored Aug 6, 2024
1 parent fa34a1b commit cd5dbea
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
6 changes: 3 additions & 3 deletions tuning/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class CandidateTracker:
calibrated_benchmark_diff: Optional[float] = None



@dataclass(frozen=True)
class PathConfig:
# Preset constants
Expand Down Expand Up @@ -300,7 +299,7 @@ def fetch_available_devices(drivers: list[str]) -> list[str]:
Extract all available devices on the user's machine for the provided drivers
Only the user provided drivers will be queried
"""
all_device_ids = []
all_device_ids: list[str] = []

for driver_name in drivers:
try:
Expand All @@ -319,6 +318,7 @@ def fetch_available_devices(drivers: list[str]) -> list[str]:

return all_device_ids


def parse_devices(devices_str: str) -> list[str]:
"""
Parse a comma-separated list of device IDs e.g.:
Expand All @@ -339,7 +339,7 @@ def validate_devices(user_devices: list[str]) -> None:
"""Validates the user provided devices against the devices extracted by the IREE Runtime"""
user_drivers = extract_driver_names(user_devices)

available_devices = fetch_available_devices(user_drivers)
available_devices = fetch_available_devices(list(user_drivers))

for device in user_devices:
handle_error(
Expand Down
82 changes: 55 additions & 27 deletions tuning/test_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

def test_group_benchmark_results_by_device_id():
def generate_res(res_arg: str, device_id: int) -> autotune.TaskResult:
result: autotune.subprocess.CompletedProcess = autotune.subprocess.CompletedProcess(
args=[res_arg],
returncode=0,
result: autotune.subprocess.CompletedProcess = (
autotune.subprocess.CompletedProcess(
args=[res_arg],
returncode=0,
)
)
return autotune.TaskResult(result=result, device_id=device_id)

Expand Down Expand Up @@ -365,58 +367,66 @@ def set_tracker(
assert (
dump_list == expect_dump_list
), "fail to parse incomplete baseline and candidates"


def test_extract_driver_names():
user_devices = ["hip://0", "local-sync://default", "cuda://default"]
expected_output = {"hip", "local-sync", "cuda"}

assert autotune.extract_driver_names(user_devices) == expected_output


def test_fetch_available_devices_success():
drivers = ["hip", "local-sync", "cuda"]
mock_devices = {
"hip": [{"path": "0"}],
"local-sync": [{"path": "default"}],
"cuda": [{"path": "default"}]
"cuda": [{"path": "default"}],
}

with patch("autotune.ireert.get_driver") as mock_get_driver:
mock_driver = MagicMock()

def get_mock_driver(name):
mock_driver.query_available_devices.side_effect = lambda: mock_devices[name]
return mock_driver

mock_get_driver.side_effect = get_mock_driver

actual_output = autotune.fetch_available_devices(drivers)
expected_output = ["hip://0", "local-sync://default", "cuda://default"]

assert actual_output == expected_output


def test_fetch_available_devices_failure():
drivers = ["hip", "local-sync", "cuda"]
mock_devices = {
"hip": [{"path": "0"}],
"local-sync": ValueError("Failed to initialize"),
"cuda": [{"path": "default"}]
"cuda": [{"path": "default"}],
}

with patch("autotune.ireert.get_driver") as mock_get_driver:
with patch("autotune.handle_error") as mock_handle_error:
mock_driver = MagicMock()

def get_mock_driver(name):
if isinstance(mock_devices[name], list):
mock_driver.query_available_devices.side_effect = lambda: mock_devices[name]
mock_driver.query_available_devices.side_effect = (
lambda: mock_devices[name]
)
else:
mock_driver.query_available_devices.side_effect = lambda: (_ for _ in ()).throw(mock_devices[name])
mock_driver.query_available_devices.side_effect = lambda: (
_ for _ in ()
).throw(mock_devices[name])
return mock_driver

mock_get_driver.side_effect = get_mock_driver

actual_output = autotune.fetch_available_devices(drivers)
expected_output = ["hip://0", "cuda://default"]

assert actual_output == expected_output
mock_handle_error.assert_called_once_with(
condition=True,
Expand All @@ -425,46 +435,64 @@ def get_mock_driver(name):
exit_program=True,
)


def test_parse_devices():
user_devices_str = "hip://0, local-sync://default, cuda://default"
expected_output = ["hip://0", "local-sync://default", "cuda://default"]

with patch("autotune.handle_error") as mock_handle_error:
actual_output = autotune.parse_devices(user_devices_str)
assert actual_output == expected_output

mock_handle_error.assert_not_called()


def test_parse_devices_with_invalid_input():
user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default"
expected_output = ["hip://0", "local-sync://default", "invalid_device", "cuda://default"]

expected_output = [
"hip://0",
"local-sync://default",
"invalid_device",
"cuda://default",
]

with patch("autotune.handle_error") as mock_handle_error:
actual_output = autotune.parse_devices(user_devices_str)
assert actual_output == expected_output

mock_handle_error.assert_called_once_with(
condition=True,
msg=f"Invalid device list: {user_devices_str}. Error: {ValueError()}",
error_type=argparse.ArgumentTypeError,
)


def test_validate_devices():
user_devices = ["hip://0", "local-sync://default"]
user_drivers = {"hip", "local-sync"}

with patch('autotune.extract_driver_names', return_value=user_drivers):
with patch('autotune.fetch_available_devices', return_value=["hip://0", "local-sync://default"]):
with patch('autotune.handle_error') as mock_handle_error:
with patch("autotune.extract_driver_names", return_value=user_drivers):
with patch(
"autotune.fetch_available_devices",
return_value=["hip://0", "local-sync://default"],
):
with patch("autotune.handle_error") as mock_handle_error:
autotune.validate_devices(user_devices)
assert all(call[1]['condition'] is False for call in mock_handle_error.call_args_list)
assert all(
call[1]["condition"] is False
for call in mock_handle_error.call_args_list
)


def test_validate_devices_with_invalid_device():
user_devices = ["hip://0", "local-sync://default", "cuda://default"]
user_drivers = {"hip", "local-sync", "cuda"}

with patch("autotune.extract_driver_names", return_value=user_drivers):
with patch("autotune.fetch_available_devices", return_value=["hip://0", "local-sync://default"]):
with patch(
"autotune.fetch_available_devices",
return_value=["hip://0", "local-sync://default"],
):
with patch("autotune.handle_error") as mock_handle_error:
autotune.validate_devices(user_devices)
expected_call = call(
Expand Down

0 comments on commit cd5dbea

Please sign in to comment.