diff --git a/tuner/tuner/__init__.py b/tuner/tuner/__init__.py index 86b8fa1bf..1570edaa1 100644 --- a/tuner/tuner/__init__.py +++ b/tuner/tuner/__init__.py @@ -4,7 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from iree.compiler import ir +from iree.compiler import ir # type: ignore + # substitute replace=True so that colliding registration don't error def register_attribute_builder(kind, replace=True): @@ -14,4 +15,5 @@ def decorator_builder(func): return decorator_builder + ir.register_attribute_builder = register_attribute_builder diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index feeb66e04..6d76c216f 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -217,10 +217,3 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: assert config.lowering_config.mma_kind is None assert config.lowering_config.subgroup_count_mn == (1, 1) - - -def test_enum_collision(): - from iree.compiler.dialects import linalg, vector - - linalg_iter_type_e = linalg._iteratortype(0, None) - vector_iter_type_e = vector._vector_iteratortype(0, None) diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 11af59af4..274fd8cae 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -499,3 +499,7 @@ def test_validate_devices_with_invalid_device() -> None: exit_program=True, ) assert expected_call in mock_handle_error.call_args_list + + +def test_enum_collision(): + from iree.compiler.dialects import linalg, vector