Skip to content

Commit

Permalink
Move accelerate to a soft-dependency (#1134)
Browse files Browse the repository at this point in the history
* finish

* finish

* Update src/diffusers/modeling_utils.py

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Anton Lozhkov <[email protected]>

* more fixes

* fix

Co-authored-by: Anton Lozhkov <[email protected]>
  • Loading branch information
patrickvonplaten and anton-l committed Nov 4, 2022
1 parent bde4880 commit 7c2a58f
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 482 deletions.
8 changes: 0 additions & 8 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .utils import (
is_accelerate_available,
is_flax_available,
is_inflect_available,
is_onnx_available,
Expand All @@ -17,13 +16,6 @@
from .utils import logging


# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
if is_torch_available() and not is_accelerate_available():
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
raise ImportError(error_msg)


if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
Expand Down
34 changes: 30 additions & 4 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@
import torch
from torch import Tensor, device

import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from . import __version__
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
is_accelerate_available,
is_torch_version,
logging,
)


logger = logging.get_logger(__name__)
Expand All @@ -41,6 +46,12 @@
_LOW_CPU_MEM_USAGE_DEFAULT = False


if is_accelerate_available():
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version


def get_parameter_device(parameter: torch.nn.Module):
try:
return next(parameter.parameters()).device
Expand Down Expand Up @@ -319,6 +330,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)

if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
" `device_map=None`. You can install accelerate with `pip install accelerate`."
)

# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import diffusers
import PIL
from accelerate.utils.versions import is_torch_version
from huggingface_hub import snapshot_download
from packaging import version
from PIL import Image
Expand All @@ -43,6 +42,8 @@
WEIGHTS_NAME,
BaseOutput,
deprecate,
is_accelerate_available,
is_torch_version,
is_transformers_available,
logging,
)
Expand Down Expand Up @@ -397,6 +398,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)

if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
is_scipy_available,
is_tf_available,
is_torch_available,
is_torch_version,
is_transformers_available,
is_unidecode_available,
requires_backends,
Expand Down
15 changes: 0 additions & 15 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
Loading

0 comments on commit 7c2a58f

Please sign in to comment.