diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index bbe40fa92e23..759dc98718c0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -670,14 +670,98 @@ def is_saveable_module(name, value): create_pr=create_pr, ) - def to( - self, - torch_device: Optional[Union[str, torch.device]] = None, - torch_dtype: Optional[torch.dtype] = None, - silence_dtype_warnings: bool = False, - ): - if torch_device is None and torch_dtype is None: - return self + def to(self, *args, **kwargs): + r""" + Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the + arguments of `self.to(*args, **kwargs).` + + + + If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, + the returned pipeline is a copy of self with the desired torch.dtype and torch.device. + + + + + Here are the ways to call `to`: + + - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the + specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + Arguments: + dtype (`torch.dtype`, *optional*): + Returns a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + device (`torch.Device`, *optional*): + Returns a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + silence_dtype_warnings (`str`, *optional*, defaults to `False`): + Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + + Returns: + [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + """ + + torch_dtype = kwargs.pop("torch_dtype", None) + if torch_dtype is not None: + deprecate("torch_dtype", "0.25.0", "") + torch_device = kwargs.pop("torch_device", None) + if torch_device is not None: + deprecate("torch_device", "0.25.0", "") + + dtype_kwarg = kwargs.pop("dtype", None) + device_kwarg = kwargs.pop("device", None) + silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) + + if torch_dtype is not None and dtype_kwarg is not None: + raise ValueError( + "You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`." + ) + + dtype = torch_dtype or dtype_kwarg + + if torch_device is not None and device_kwarg is not None: + raise ValueError( + "You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`." + ) + + device = torch_device or device_kwarg + + dtype_arg = None + device_arg = None + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_arg = args[0] + else: + device_arg = torch.device(args[0]) if args[0] is not None else None + elif len(args) == 2: + if isinstance(args[0], torch.dtype): + raise ValueError( + "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." + ) + device_arg = torch.device(args[0]) if args[0] is not None else None + dtype_arg = args[1] + elif len(args) > 2: + raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") + + if dtype is not None and dtype_arg is not None: + raise ValueError( + "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + dtype = dtype or dtype_arg + + if device is not None and device_arg is not None: + raise ValueError( + "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + device = device or device_arg # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. def module_is_sequentially_offloaded(module): @@ -698,14 +782,14 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - if pipeline_is_sequentially_offloaded and torch_device and torch.device(torch_device).type == "cuda": + if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." ) # Display a warning in this case (the operation succeeds but the benefits are lost) pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and torch_device and torch.device(torch_device).type == "cuda": + if pipeline_is_offloaded and device and torch.device(device).type == "cuda": logger.warning( f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." ) @@ -718,26 +802,26 @@ def module_is_offloaded(module): for module in modules: is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit - if is_loaded_in_8bit and torch_dtype is not None: + if is_loaded_in_8bit and dtype is not None: logger.warning( f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." ) - if is_loaded_in_8bit and torch_device is not None: + if is_loaded_in_8bit and device is not None: logger.warning( f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." ) else: - module.to(torch_device, torch_dtype) + module.to(device, dtype) if ( module.dtype == torch.float16 - and str(torch_device) in ["cpu"] + and str(device) in ["cpu"] and not silence_dtype_warnings and not is_offloaded ): logger.warning( - "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" " is not recommended to move them to `cpu` as running them will fail. Please make" " sure to use an accelerator to run the pipeline in inference, due to the lack of" " support for`float16` operations on this device in PyTorch. Please, remove the" @@ -760,6 +844,21 @@ def device(self) -> torch.device: return torch.device("cpu") + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" @@ -1222,12 +1321,19 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"): + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + + Arguments: + gpu_id (`int`, *optional*): + The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. + device (`torch.Device` or `str`, *optional*, defaults to "cuda"): + The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will + default to "cuda". """ if self.model_cpu_offload_seq is None: raise ValueError( @@ -1239,7 +1345,20 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - device = torch.device(f"cuda:{gpu_id}") + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 + self._offload_gpu_id = gpu_id or torch_device.index or self._offload_gpu_id or 0 + + device_type = torch_device.type + device = torch.device(f"{device_type}:{self._offload_gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) @@ -1274,7 +1393,10 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, def maybe_free_model_hooks(self): r""" - TODO: Better doc string + Function that offloads all components, removes all model hooks that were added when using + `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function + is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it + functions correctly when applying enable_model_cpu_offload. """ if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing @@ -1288,21 +1410,40 @@ def maybe_free_model_hooks(self): # make sure the model is in the same state as before calling it self.enable_model_cpu_offload() - def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. Offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. + + Arguments: + gpu_id (`int`, *optional*): + The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. + device (`torch.Device` or `str`, *optional*, defaults to "cuda"): + The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will + default to "cuda". """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - if device == "cuda": - device = torch.device(f"{device}:{gpu_id}") + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 + self._offload_gpu_id = gpu_id or torch_device.index or self._offload_gpu_id or 0 + + device_type = torch_device.type + device = torch.device(f"{device_type}:{self._offload_gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 5a0c300c60c4..e4c550183562 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1475,6 +1475,89 @@ def get_all_filenames(directory): assert len(variant_model_files) == 0 assert len(all_model_files) > 0 + def test_pipe_to(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + device_type = torch.device(torch_device).type + + sd1 = sd.to(device_type) + sd2 = sd.to(torch.device(device_type)) + sd3 = sd.to(device_type, torch.float32) + sd4 = sd.to(device=device_type) + sd5 = sd.to(torch_device=device_type) + sd6 = sd.to(device_type, dtype=torch.float32) + sd7 = sd.to(device_type, torch_dtype=torch.float32) + + assert sd1.device.type == device_type + assert sd2.device.type == device_type + assert sd3.device.type == device_type + assert sd4.device.type == device_type + assert sd5.device.type == device_type + assert sd6.device.type == device_type + assert sd7.device.type == device_type + + sd1 = sd.to(torch.float16) + sd2 = sd.to(None, torch.float16) + sd3 = sd.to(dtype=torch.float16) + sd4 = sd.to(torch_dtype=torch.float16) + sd5 = sd.to(None, dtype=torch.float16) + sd6 = sd.to(None, torch_dtype=torch.float16) + + assert sd1.dtype == torch.float16 + assert sd2.dtype == torch.float16 + assert sd3.dtype == torch.float16 + assert sd4.dtype == torch.float16 + assert sd5.dtype == torch.float16 + assert sd6.dtype == torch.float16 + + sd1 = sd.to(device=device_type, dtype=torch.float16) + sd2 = sd.to(torch_device=device_type, torch_dtype=torch.float16) + sd3 = sd.to(device_type, torch.float16) + + assert sd1.dtype == torch.float16 + assert sd2.dtype == torch.float16 + assert sd3.dtype == torch.float16 + + assert sd1.device.type == device_type + assert sd2.device.type == device_type + assert sd3.device.type == device_type + + def test_pipe_same_device_id_offload(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + sd.enable_model_cpu_offload(gpu_id=5) + assert sd._offload_gpu_id == 5 + sd.maybe_free_model_hooks() + assert sd._offload_gpu_id == 5 + @slow @require_torch_gpu diff --git a/tests/pipelines/text_to_video/test_video_to_video.py b/tests/pipelines/text_to_video/test_video_to_video.py index 2b4af2617d2c..397d65cae52f 100644 --- a/tests/pipelines/text_to_video/test_video_to_video.py +++ b/tests/pipelines/text_to_video/test_video_to_video.py @@ -165,6 +165,10 @@ def test_save_load_optional_components(self): def test_dict_tuple_outputs_equivalent(self): super().test_dict_tuple_outputs_equivalent() + @is_flaky() + def test_save_load_local(self): + super().test_save_load_local() + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed",