From cd5dbea85093d63353c0bf0e5b5f9f44935f5401 Mon Sep 17 00:00:00 2001 From: RattataKing <46631728+RattataKing@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:03:57 -0400 Subject: [PATCH] [tuner] Fix new lint err (#93) Fix new lint errors for #60 --- tuning/autotune.py | 6 +-- tuning/test_autotune.py | 82 +++++++++++++++++++++++++++-------------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/tuning/autotune.py b/tuning/autotune.py index cd956a2..a6ddda8 100755 --- a/tuning/autotune.py +++ b/tuning/autotune.py @@ -75,7 +75,6 @@ class CandidateTracker: calibrated_benchmark_diff: Optional[float] = None - @dataclass(frozen=True) class PathConfig: # Preset constants @@ -300,7 +299,7 @@ def fetch_available_devices(drivers: list[str]) -> list[str]: Extract all available devices on the user's machine for the provided drivers Only the user provided drivers will be queried """ - all_device_ids = [] + all_device_ids: list[str] = [] for driver_name in drivers: try: @@ -319,6 +318,7 @@ def fetch_available_devices(drivers: list[str]) -> list[str]: return all_device_ids + def parse_devices(devices_str: str) -> list[str]: """ Parse a comma-separated list of device IDs e.g.: @@ -339,7 +339,7 @@ def validate_devices(user_devices: list[str]) -> None: """Validates the user provided devices against the devices extracted by the IREE Runtime""" user_drivers = extract_driver_names(user_devices) - available_devices = fetch_available_devices(user_drivers) + available_devices = fetch_available_devices(list(user_drivers)) for device in user_devices: handle_error( diff --git a/tuning/test_autotune.py b/tuning/test_autotune.py index adabc3c..a2bac83 100644 --- a/tuning/test_autotune.py +++ b/tuning/test_autotune.py @@ -10,9 +10,11 @@ def test_group_benchmark_results_by_device_id(): def generate_res(res_arg: str, device_id: int) -> autotune.TaskResult: - result: autotune.subprocess.CompletedProcess = autotune.subprocess.CompletedProcess( - args=[res_arg], - returncode=0, + result: autotune.subprocess.CompletedProcess = ( + autotune.subprocess.CompletedProcess( + args=[res_arg], + returncode=0, + ) ) return autotune.TaskResult(result=result, device_id=device_id) @@ -365,58 +367,66 @@ def set_tracker( assert ( dump_list == expect_dump_list ), "fail to parse incomplete baseline and candidates" + + def test_extract_driver_names(): user_devices = ["hip://0", "local-sync://default", "cuda://default"] expected_output = {"hip", "local-sync", "cuda"} assert autotune.extract_driver_names(user_devices) == expected_output + def test_fetch_available_devices_success(): drivers = ["hip", "local-sync", "cuda"] mock_devices = { "hip": [{"path": "0"}], "local-sync": [{"path": "default"}], - "cuda": [{"path": "default"}] + "cuda": [{"path": "default"}], } - + with patch("autotune.ireert.get_driver") as mock_get_driver: mock_driver = MagicMock() - + def get_mock_driver(name): mock_driver.query_available_devices.side_effect = lambda: mock_devices[name] return mock_driver - + mock_get_driver.side_effect = get_mock_driver - + actual_output = autotune.fetch_available_devices(drivers) expected_output = ["hip://0", "local-sync://default", "cuda://default"] - + assert actual_output == expected_output + def test_fetch_available_devices_failure(): drivers = ["hip", "local-sync", "cuda"] mock_devices = { "hip": [{"path": "0"}], "local-sync": ValueError("Failed to initialize"), - "cuda": [{"path": "default"}] + "cuda": [{"path": "default"}], } - + with patch("autotune.ireert.get_driver") as mock_get_driver: with patch("autotune.handle_error") as mock_handle_error: mock_driver = MagicMock() - + def get_mock_driver(name): if isinstance(mock_devices[name], list): - mock_driver.query_available_devices.side_effect = lambda: mock_devices[name] + mock_driver.query_available_devices.side_effect = ( + lambda: mock_devices[name] + ) else: - mock_driver.query_available_devices.side_effect = lambda: (_ for _ in ()).throw(mock_devices[name]) + mock_driver.query_available_devices.side_effect = lambda: ( + _ for _ in () + ).throw(mock_devices[name]) return mock_driver - + mock_get_driver.side_effect = get_mock_driver - + actual_output = autotune.fetch_available_devices(drivers) expected_output = ["hip://0", "cuda://default"] - + assert actual_output == expected_output mock_handle_error.assert_called_once_with( condition=True, @@ -425,46 +435,64 @@ def get_mock_driver(name): exit_program=True, ) + def test_parse_devices(): user_devices_str = "hip://0, local-sync://default, cuda://default" expected_output = ["hip://0", "local-sync://default", "cuda://default"] - + with patch("autotune.handle_error") as mock_handle_error: actual_output = autotune.parse_devices(user_devices_str) assert actual_output == expected_output - + mock_handle_error.assert_not_called() + def test_parse_devices_with_invalid_input(): user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default" - expected_output = ["hip://0", "local-sync://default", "invalid_device", "cuda://default"] - + expected_output = [ + "hip://0", + "local-sync://default", + "invalid_device", + "cuda://default", + ] + with patch("autotune.handle_error") as mock_handle_error: actual_output = autotune.parse_devices(user_devices_str) assert actual_output == expected_output - + mock_handle_error.assert_called_once_with( condition=True, msg=f"Invalid device list: {user_devices_str}. Error: {ValueError()}", error_type=argparse.ArgumentTypeError, ) + def test_validate_devices(): user_devices = ["hip://0", "local-sync://default"] user_drivers = {"hip", "local-sync"} - with patch('autotune.extract_driver_names', return_value=user_drivers): - with patch('autotune.fetch_available_devices', return_value=["hip://0", "local-sync://default"]): - with patch('autotune.handle_error') as mock_handle_error: + with patch("autotune.extract_driver_names", return_value=user_drivers): + with patch( + "autotune.fetch_available_devices", + return_value=["hip://0", "local-sync://default"], + ): + with patch("autotune.handle_error") as mock_handle_error: autotune.validate_devices(user_devices) - assert all(call[1]['condition'] is False for call in mock_handle_error.call_args_list) + assert all( + call[1]["condition"] is False + for call in mock_handle_error.call_args_list + ) + def test_validate_devices_with_invalid_device(): user_devices = ["hip://0", "local-sync://default", "cuda://default"] user_drivers = {"hip", "local-sync", "cuda"} with patch("autotune.extract_driver_names", return_value=user_drivers): - with patch("autotune.fetch_available_devices", return_value=["hip://0", "local-sync://default"]): + with patch( + "autotune.fetch_available_devices", + return_value=["hip://0", "local-sync://default"], + ): with patch("autotune.handle_error") as mock_handle_error: autotune.validate_devices(user_devices) expected_call = call(