Skip to content

Commit

Permalink
[Tests] Add is flaky decorator (#5139)
Browse files Browse the repository at this point in the history
* add is flaky decorator

* fix more
  • Loading branch information
patrickvonplaten authored Sep 25, 2023
1 parent 787195f commit 22b19d5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import importlib
import inspect
import io
Expand All @@ -7,7 +8,9 @@
import random
import re
import struct
import sys
import tempfile
import time
import unittest
import urllib.parse
from contextlib import contextmanager
Expand Down Expand Up @@ -612,6 +615,43 @@ def summary_failures_short(tr):
config.option.tbstyle = orig_tbstyle


# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
To decorate flaky tests. They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
The maximum number of attempts to retry the flaky test.
wait_before_retry (`float`, *optional*):
If provided, will wait that number of seconds before retrying the test.
description (`str`, *optional*):
A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
etc.)
"""

def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1

while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)

except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1

return test_func_ref(*args, **kwargs)

return wrapper

return decorator


# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/pipelines/text_to_video/test_video_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
is_flaky,
skip_mps,
slow,
torch_device,
Expand Down Expand Up @@ -156,9 +157,14 @@ def test_text_to_video_default_case(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@is_flaky()
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=0.001)

@is_flaky()
def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent()

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
Expand Down

0 comments on commit 22b19d5

Please sign in to comment.