diff --git a/utils/testing.py b/utils/testing.py index b122554b..19b16a54 100644 --- a/utils/testing.py +++ b/utils/testing.py @@ -299,7 +299,7 @@ def parameterized_list(cases: list[T]) -> list[tuple[str, T]]: def unstable_test( max_retries: int = 2, - error_class: Type[BaseException] = AssertionError, + error_class: Type[BaseException] | tuple[Type[BaseException], ...] = AssertionError, wait_between_runs: float = 0, ) -> Callable[ [Callable[Concatenate[unittest.TestCase, P], T]], Callable[Concatenate[unittest.TestCase, P], T] @@ -322,14 +322,14 @@ def decorator( def wrapper(self: unittest.TestCase, *args: P.args, **kwargs: P.kwargs) -> T: try: return func(self, *args, **kwargs) # Initial attempt to run the test "normally" - except (error_class,): + except error_class: last_error = None for _ in range(max_retries): sleep(wait_between_runs) try: self.setUp() # Need to rerun setup return func(self, *args, **kwargs) - except (error_class,) as e: + except error_class as e: last_error = e finally: self.tearDown() # Rerun tearDown regardless of success or not