diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9a9ad17e7a4e..32a1738e9d04 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,7 +5,6 @@ import numpy as np import torch -from transformers import is_torchvision_available from .models import UNet2DConditionModel from .utils import ( @@ -13,6 +12,7 @@ convert_state_dict_to_peft, deprecate, is_peft_available, + is_torchvision_available, is_transformers_available, ) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 667f1fe5e2fd..2da6ec7e7df8 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -75,6 +75,7 @@ is_torch_version, is_torch_xla_available, is_torchsde_available, + is_torchvision_available, is_transformers_available, is_transformers_version, is_unidecode_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index ac1565023b09..e89ac0b7723f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -278,6 +278,13 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False +_torchvision_available = importlib.util.find_spec("torchvision") is not None +try: + _torchvision_version = importlib_metadata.version("torchvision") + logger.debug(f"Successfully imported torchvision version {_torchvision_version}") +except importlib_metadata.PackageNotFoundError: + _torchvision_available = False + def is_torch_available(): return _torch_available @@ -367,6 +374,10 @@ def is_peft_available(): return _peft_available +def is_torchvision_available(): + return _torchvision_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the