From 2a7b84ee560091eaf9bba778ad3b1c20da1dd4a2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 20 Oct 2023 01:12:24 +0100 Subject: [PATCH] Use flag() CM instead of custom one --- test/test_models.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 33c6a84c941..088ea1bf7fa 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -25,16 +25,6 @@ SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" -@contextlib.contextmanager -def disable_tf32(): - previous = torch.backends.cudnn.allow_tf32 - torch.backends.cudnn.allow_tf32 = False - try: - yield - finally: - torch.backends.cudnn.allow_tf32 = previous - - def list_model_fns(module): return [get_model_builder(name) for name in list_models(module)] @@ -681,7 +671,7 @@ def test_vitc_models(model_fn, dev): test_classification_model(model_fn, dev) -@disable_tf32() # see: https://github.com/pytorch/vision/issues/7618 +@torch.backends.cudnn.flags(allow_tf32=False) # see: https://github.com/pytorch/vision/issues/7618 @pytest.mark.parametrize("model_fn", list_model_fns(models)) @pytest.mark.parametrize("dev", cpu_and_cuda()) def test_classification_model(model_fn, dev):