diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 596e5c4868fe..9a9ad17e7a4e 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,7 +5,7 @@ import numpy as np import torch -from torchvision import transforms +from transformers import is_torchvision_available from .models import UNet2DConditionModel from .utils import ( @@ -23,6 +23,9 @@ if is_peft_available(): from peft import set_peft_model_state_dict +if is_torchvision_available(): + from torchvision import transforms + def set_seed(seed: int): """ @@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str): `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` transform. """ + if not is_torchvision_available(): + raise ImportError( + "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function." + ) + if interpolation_type == "bilinear": interpolation_mode = transforms.InterpolationMode.BILINEAR elif interpolation_type == "bicubic":