[Core] Improve .to(...) method, fix offloads multi-gpu, add docstring, add dtype (#5132)

* fix cpu offload

* fix

* fix

* Update src/diffusers/pipelines/pipeline_utils.py

* make style

* Apply suggestions from code review

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>

* fix more

* fix more

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
3 people authored Sep 25, 2023
1 parent 92f15f5 commit 30a512e
Showing 3 changed files with 249 additions and 21 deletions.
183 changes: 162 additions & 21 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -670,14 +670,98 @@ def is_saveable_module(name, value):
create_pr=create_pr,
)

def to(
self,
torch_device: Optional[Union[str, torch.device]] = None,
torch_dtype: Optional[torch.dtype] = None,
silence_dtype_warnings: bool = False,
):
if torch_device is None and torch_dtype is None:
return self
def to(self, *args, **kwargs):
r"""
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs)`.
<Tip>
If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
</Tip>
Here are the ways to call `to`:
- `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
- `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
- `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
Arguments:
dtype (`torch.dtype`, *optional*):
Returns a pipeline with the specified
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
device (`torch.device`, *optional*):
Returns a pipeline with the specified
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
Returns:
[`DiffusionPipeline`]: The pipeline converted to the specified `dtype` and/or `device`.
"""

torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
deprecate("torch_dtype", "0.25.0", "")
torch_device = kwargs.pop("torch_device", None)
if torch_device is not None:
deprecate("torch_device", "0.25.0", "")

dtype_kwarg = kwargs.pop("dtype", None)
device_kwarg = kwargs.pop("device", None)
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

if torch_dtype is not None and dtype_kwarg is not None:
raise ValueError(
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
)

dtype = torch_dtype or dtype_kwarg

if torch_device is not None and device_kwarg is not None:
raise ValueError(
"You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
)

device = torch_device or device_kwarg

dtype_arg = None
device_arg = None
if len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype_arg = args[0]
else:
device_arg = torch.device(args[0]) if args[0] is not None else None
elif len(args) == 2:
if isinstance(args[0], torch.dtype):
raise ValueError(
"When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
)
device_arg = torch.device(args[0]) if args[0] is not None else None
dtype_arg = args[1]
elif len(args) > 2:
raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")

if dtype is not None and dtype_arg is not None:
raise ValueError(
"You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
)

dtype = dtype or dtype_arg

if device is not None and device_arg is not None:
raise ValueError(
"You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
)

device = device or device_arg

# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
@@ -698,14 +782,14 @@ def module_is_offloaded(module):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
if pipeline_is_sequentially_offloaded and torch_device and torch.device(torch_device).type == "cuda":
if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)

# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and torch_device and torch.device(torch_device).type == "cuda":
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
Expand All @@ -718,26 +802,26 @@ def module_is_offloaded(module):
for module in modules:
is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit

if is_loaded_in_8bit and torch_dtype is not None:
if is_loaded_in_8bit and dtype is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
)

if is_loaded_in_8bit and torch_device is not None:
if is_loaded_in_8bit and device is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
)
else:
module.to(torch_device, torch_dtype)
module.to(device, dtype)

if (
module.dtype == torch.float16
and str(torch_device) in ["cpu"]
and str(device) in ["cpu"]
and not silence_dtype_warnings
and not is_offloaded
):
logger.warning(
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
"Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` as running them will fail. Please make"
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
" support for`float16` operations on this device in PyTorch. Please, remove the"
@@ -760,6 +844,21 @@ def device(self) -> torch.device:

return torch.device("cpu")

@property
def dtype(self) -> torch.dtype:
r"""
Returns:
`torch.dtype`: The torch dtype on which the pipeline is located.
"""
module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]

for module in modules:
return module.dtype

return torch.float32
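
# Illustrative sketch (not part of this diff): the new `dtype` property reports the dtype of the
# pipeline's first torch.nn.Module, so a cast via `.to(...)` is reflected here, e.g.:
#     pipe = pipe.to(dtype=torch.float16)
#     assert pipe.dtype == torch.float16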

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
@@ -1222,12 +1321,19 @@ def _execution_device(self):
return torch.device(module._hf_hook.execution_device)
return self.device

def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
Arguments:
gpu_id (`int`, *optional*):
The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
device (`torch.device` or `str`, *optional*, defaults to "cuda"):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
if self.model_cpu_offload_seq is None:
raise ValueError(
@@ -1239,7 +1345,20 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device,
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
torch_device = torch.device(device)
device_index = torch_device.index

if gpu_id is not None and device_index is not None:
raise ValueError(
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
)

# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)

device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
@@ -1274,7 +1393,10 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device,

def maybe_free_model_hooks(self):
r"""
TODO: Better doc string
Function that offloads all components, removes all model hooks that were added when using
`enable_model_cpu_offload`, and then applies them again. In case the model has not been offloaded, this
function is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so
that it functions correctly when applying `enable_model_cpu_offload`.
"""
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
# `enable_model_cpu_offload` has not be called, so silently do nothing
@@ -1288,21 +1410,40 @@ def maybe_free_model_hooks(self):
# make sure the model is in the same state as before calling it
self.enable_model_cpu_offload()
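
# Illustrative sketch (not part of this diff): a custom pipeline's `__call__` would typically end with
# `maybe_free_model_hooks` so that offloading keeps working across repeated calls. `MyPipeline`,
# `run_denoising_loop` and `decode_latents` are hypothetical names used only for illustration:
#
#     class MyPipeline(DiffusionPipeline):
#         @torch.no_grad()
#         def __call__(self, prompt):
#             latents = run_denoising_loop(self, prompt)   # forward passes trigger the offload hooks
#             image = decode_latents(self, latents)
#             self.maybe_free_model_hooks()                # re-arm the cpu-offload hooks for the next call
#             return image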

def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
method called. Offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
Arguments:
gpu_id (`int`, *optional*):
The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
device (`torch.device` or `str`, *optional*, defaults to "cuda"):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

if device == "cuda":
device = torch.device(f"{device}:{gpu_id}")
torch_device = torch.device(device)
device_index = torch_device.index

if gpu_id is not None and device_index is not None:
raise ValueError(
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
)

# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)

device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
83 changes: 83 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -1475,6 +1475,89 @@ def get_all_filenames(directory):
assert len(variant_model_files) == 0
assert len(all_model_files) > 0

def test_pipe_to(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

device_type = torch.device(torch_device).type

sd1 = sd.to(device_type)
sd2 = sd.to(torch.device(device_type))
sd3 = sd.to(device_type, torch.float32)
sd4 = sd.to(device=device_type)
sd5 = sd.to(torch_device=device_type)
sd6 = sd.to(device_type, dtype=torch.float32)
sd7 = sd.to(device_type, torch_dtype=torch.float32)

assert sd1.device.type == device_type
assert sd2.device.type == device_type
assert sd3.device.type == device_type
assert sd4.device.type == device_type
assert sd5.device.type == device_type
assert sd6.device.type == device_type
assert sd7.device.type == device_type

sd1 = sd.to(torch.float16)
sd2 = sd.to(None, torch.float16)
sd3 = sd.to(dtype=torch.float16)
sd4 = sd.to(torch_dtype=torch.float16)
sd5 = sd.to(None, dtype=torch.float16)
sd6 = sd.to(None, torch_dtype=torch.float16)

assert sd1.dtype == torch.float16
assert sd2.dtype == torch.float16
assert sd3.dtype == torch.float16
assert sd4.dtype == torch.float16
assert sd5.dtype == torch.float16
assert sd6.dtype == torch.float16

sd1 = sd.to(device=device_type, dtype=torch.float16)
sd2 = sd.to(torch_device=device_type, torch_dtype=torch.float16)
sd3 = sd.to(device_type, torch.float16)

assert sd1.dtype == torch.float16
assert sd2.dtype == torch.float16
assert sd3.dtype == torch.float16

assert sd1.device.type == device_type
assert sd2.device.type == device_type
assert sd3.device.type == device_type

def test_pipe_same_device_id_offload(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

sd.enable_model_cpu_offload(gpu_id=5)
assert sd._offload_gpu_id == 5
sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5


@slow
@require_torch_gpu
4 changes: 4 additions & 0 deletions tests/pipelines/text_to_video/test_video_to_video.py
@@ -165,6 +165,10 @@ def test_save_load_optional_components(self):
def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent()

@is_flaky()
def test_save_load_local(self):
super().test_save_load_local()

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
