From 22b19d578e57f9b152eef4444738da68bbb33ce7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 25 Sep 2023 13:24:44 +0200 Subject: [PATCH] [Tests] Add is flaky decorator (#5139) * add is flaky decorator * fix more --- src/diffusers/utils/testing_utils.py | 40 +++++++++++++++++++ .../text_to_video/test_video_to_video.py | 6 +++ 2 files changed, 46 insertions(+) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index a2f4de439e11..1079299db84d 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,3 +1,4 @@ +import functools import importlib import inspect import io @@ -7,7 +8,9 @@ import random import re import struct +import sys import tempfile +import time import unittest import urllib.parse from contextlib import contextmanager @@ -612,6 +615,43 @@ def summary_failures_short(tr): config.option.tbstyle = orig_tbstyle +# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905 +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + description (`str`, *optional*): + A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors, + etc.) + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + # Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787 def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): """ diff --git a/tests/pipelines/text_to_video/test_video_to_video.py b/tests/pipelines/text_to_video/test_video_to_video.py index f057eb34997e..2b4af2617d2c 100644 --- a/tests/pipelines/text_to_video/test_video_to_video.py +++ b/tests/pipelines/text_to_video/test_video_to_video.py @@ -30,6 +30,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, + is_flaky, skip_mps, slow, torch_device, @@ -156,9 +157,14 @@ def test_text_to_video_default_case(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @is_flaky() def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=0.001) + @is_flaky() + def test_dict_tuple_outputs_equivalent(self): + super().test_dict_tuple_outputs_equivalent() + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed",