From 9c0e20de61a6e0adcec706564cee739520c1d2f4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 19 Dec 2024 10:24:57 +0530 Subject: [PATCH 01/13] [chore] Update README_sana.md to update the default model (#10285) Update README_sana.md to update the default model --- examples/dreambooth/README_sana.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md index fe861d62472b..d82529c64de8 100644 --- a/examples/dreambooth/README_sana.md +++ b/examples/dreambooth/README_sana.md @@ -73,7 +73,7 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face Now, we can launch training using: ```bash -export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers" +export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers" export INSTANCE_DIR="dog" export OUTPUT_DIR="trained-sana-lora" @@ -124,4 +124,4 @@ We provide several options for optimizing memory optimization: * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done. * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library. -Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference. \ No newline at end of file +Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference. From f781b8c30c4d70fbf0afcc9799c7f9e9693b2921 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 19 Dec 2024 10:28:10 +0530 Subject: [PATCH 02/13] Hunyuan VAE tiling fixes and transformer docs (#10295) * update * udpate * fix test --- .../autoencoder_kl_hunyuan_video.py | 8 ++-- .../transformers/transformer_hunyuan_video.py | 40 +++++++++++++++++++ .../test_models_autoencoder_hunyuan_video.py | 25 ++++++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index bded90a8bcff..5c1d94d4e18f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -792,12 +792,12 @@ def __init__( # The minimal tile height and width for spatial tiling to be used self.tile_sample_min_height = 256 self.tile_sample_min_width = 256 - self.tile_sample_min_num_frames = 64 + self.tile_sample_min_num_frames = 16 # The minimal distance between two spatial tiles self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 - self.tile_sample_stride_num_frames = 48 + self.tile_sample_stride_num_frames = 12 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): @@ -1003,7 +1003,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: for i in range(0, height, self.tile_sample_stride_height): row = [] for j in range(0, width, self.tile_sample_stride_width): - tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] tile = self.encoder(tile) tile = self.quant_conv(tile) row.append(tile) @@ -1020,7 +1020,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) - result_rows.append(torch.cat(result_row, dim=-1)) + result_rows.append(torch.cat(result_row, dim=4)) enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] return enc diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index d8f9834ea61c..737be99c5a10 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -497,6 +497,46 @@ def forward( class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_single_layers (`int`, defaults to `40`): + The number of layers of single-stream blocks to use. + num_refiner_layers (`int`, defaults to `2`): + The number of layers of refiner blocks to use. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings in the model. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + pooled_projection_dim (`int`, defaults to `768`): + The dimension of the pooled projection of the text embeddings. + rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + """ + + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index 826ac30d5f2f..7b7901a6fd94 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -43,10 +43,14 @@ def get_autoencoder_kl_hunyuan_video_config(self): "down_block_types": ( "HunyuanVideoDownBlock3D", "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", ), "up_block_types": ( "HunyuanVideoUpBlock3D", "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", ), "block_out_channels": (8, 8, 8, 8), "layers_per_block": 1, @@ -154,6 +158,27 @@ def test_gradient_checkpointing_is_applied(self): } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + # We need to overwrite this test because the base test does not account length of down_block_types + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 16, 16, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + @unittest.skip("Unsupported test.") def test_outputs_equivalence(self): pass From 4450d26b63b4f6e7736ca86f11d0c37827159bfa Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 08:28:56 +0000 Subject: [PATCH 03/13] Add Flux Control to AutoPipeline (#10292) --- src/diffusers/pipelines/auto_pipeline.py | 37 ++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index a0f95fe6cdc1..f3a05c2c661f 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -35,9 +35,12 @@ ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import ( + FluxControlImg2ImgPipeline, + FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, @@ -125,6 +128,7 @@ ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), ("flux", FluxPipeline), + ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), ("lumina", LuminaText2ImgPipeline), ("cogview3", CogView3PlusPipeline), @@ -150,6 +154,7 @@ ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline), + ("flux-control", FluxControlImg2ImgPipeline), ] ) @@ -168,6 +173,7 @@ ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline), + ("flux-control", FluxControlInpaintPipeline), ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline), ] ) @@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + if "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") + orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline") else: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline") if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") + orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline") text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): # the `orig_class_name` can be: # `- *Pipeline` (for regular text-to-image checkpoint) + # - `*ControlPipeline` (for Flux tools specific checkpoint) # `- *Img2ImgPipeline` (for refiner checkpoint) - to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" + if "Img2Img" in orig_class_name: + to_replace = "Img2ImgPipeline" + elif "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): @@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if enable_pag: orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + if to_replace == "ControlPipeline": + orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline") + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} @@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): # The `orig_class_name`` can be: # `- *InpaintPipeline` (for inpaint-specific checkpoint) + # - `*ControlPipeline` (for Flux tools specific checkpoint) # - or *Pipeline (for regular text-to-image checkpoint) - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + if "Inpaint" in orig_class_name: + to_replace = "InpaintPipeline" + elif "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): @@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): enable_pag = kwargs.pop("enable_pag") if enable_pag: orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + if to_replace == "ControlPipeline": + orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline") inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} From 2f7a417d1fb11bd242ad7f9098bb9fdf77c54422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E4=B8=89=E7=9F=B3?= <49309820+zhaowendao30@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:07:50 +0800 Subject: [PATCH 04/13] Update lora_conversion_utils.py (#9980) x-flux single-blocks lora load Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- src/diffusers/loaders/lora_conversion_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index aab87b8f4dba..07c2c2272422 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -643,7 +643,11 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): old_state_dict, new_state_dict, old_key, - [f"transformer.single_transformer_blocks.{block_num}.norm.linear"], + [ + f"transformer.single_transformer_blocks.{block_num}.attn.to_q", + f"transformer.single_transformer_blocks.{block_num}.attn.to_k", + f"transformer.single_transformer_blocks.{block_num}.attn.to_v", + ], ) if "down" in old_key: From 0ed09a17bbab784a78fb163b557b4827467b0468 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 09:24:52 +0000 Subject: [PATCH 05/13] Check correct model type is passed to `from_pretrained` (#10189) * Check correct model type is passed to `from_pretrained` * Flax, skip scheduler * test_wrong_model * Fix for scheduler * Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul * EnumMeta * Flax * scheduler in expected types * make * type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' * support union * fix typing in kandinsky * make * add LCMScheduler * 'LCMScheduler' object has no attribute 'sigmas' * tests for wrong scheduler * make * update * warning * tests * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair * import FlaxSchedulerMixin * skip scheduler --------- Co-authored-by: Sayak Paul Co-authored-by: Dhruv Nair --- src/diffusers/pipelines/pipeline_utils.py | 22 ++++++++++++++++++++++ tests/pipelines/test_pipelines.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a504184ea2f2..c505c5a262a3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import fnmatch import importlib import inspect @@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + expected_types = pipeline_class._get_signature_types() passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -833,6 +835,26 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + for key in init_dict.keys(): + if key not in passed_class_obj: + continue + if "scheduler" in key: + continue + + class_obj = passed_class_obj[key] + _expected_class_types = [] + for expected_type in expected_types[key]: + if isinstance(expected_type, enum.EnumMeta): + _expected_class_types.extend(expected_type.__members__.keys()) + else: + _expected_class_types.append(expected_type.__name__) + + _is_valid_type = class_obj.__class__.__name__ in _expected_class_types + if not _is_valid_type: + logger.warning( + f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." + ) + # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 43b01c40f5bb..423c82e0602e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1802,6 +1802,16 @@ def test_pipe_same_device_id_offload(self): sd.maybe_free_model_hooks() assert sd._offload_gpu_id == 5 + def test_wrong_model(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + with self.assertRaises(ValueError) as error_context: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer + ) + + assert "is of type" in str(error_context.exception) + assert "but should be" in str(error_context.exception) + @slow @require_torch_gpu From 1826a1e7d31df48d345a20028b3ace48f09a4e60 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:52:20 +0800 Subject: [PATCH 06/13] [LoRA] Support HunyuanVideo (#10254) * 1217 * 1217 * 1217 * update * reverse * add test * update test * make style * update * make style --------- Co-authored-by: Aryan --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 308 ++++++++++++++++++ src/diffusers/loaders/peft.py | 1 + .../transformers/transformer_hunyuan_video.py | 28 +- .../hunyuan_video/pipeline_hunyuan_video.py | 14 +- tests/lora/test_lora_layers_hunyuanvideo.py | 228 +++++++++++++ tests/lora/utils.py | 34 +- 7 files changed, 600 insertions(+), 15 deletions(-) create mode 100644 tests/lora/test_lora_layers_hunyuanvideo.py diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b59150376599..6ea382d721de 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder): "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", "Mochi1LoraLoaderMixin", + "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] @@ -90,6 +91,7 @@ def text_encoder_attn_modules(text_encoder): AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b8c44e480093..46d744233014 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) +class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`HunyuanVideoTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a791a250af08..9c00012ebc65 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -53,6 +53,7 @@ "FluxTransformer2DModel": lambda model_cls, weights: weights, "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, + "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 737be99c5a10..089389b5f9ad 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -19,7 +19,8 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor from ..embeddings import ( @@ -32,6 +33,9 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class HunyuanVideoAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -496,7 +500,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -670,8 +674,24 @@ def forward( encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config.patch_size, self.config.patch_size_t post_patch_num_frames = num_frames // p_t @@ -757,6 +777,10 @@ def custom_forward(*inputs): hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (hidden_states,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index bd3d3c1e8485..4423ccf97932 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -20,6 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring @@ -132,7 +133,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class HunyuanVideoPipeline(DiffusionPipeline): +class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation using HunyuanVideo. @@ -447,6 +448,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -471,6 +476,7 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -525,6 +531,10 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. @@ -562,6 +572,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False device = self._execution_device @@ -640,6 +651,7 @@ def __call__( encoder_attention_mask=prompt_attention_mask, pooled_projections=pooled_prompt_embeds, guidance=guidance, + attention_kwargs=attention_kwargs, return_dict=False, )[0] diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py new file mode 100644 index 000000000000..59464c052684 --- /dev/null +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -0,0 +1,228 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import pytest +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + floats_tensor, + is_torch_version, + require_peft_backend, + skip_mps, + torch_device, +) + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +@require_peft_backend +@skip_mps +class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = HunyuanVideoPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + transformer_cls = HunyuanVideoTransformer3DModel + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + vae_cls = AutoencoderKLHunyuanVideo + has_two_text_encoders = True + tokenizer_cls, tokenizer_id, tokenizer_subfolder = ( + LlamaTokenizerFast, + "hf-internal-testing/tiny-random-hunyuanvideo", + "tokenizer", + ) + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = ( + CLIPTokenizer, + "hf-internal-testing/tiny-random-hunyuanvideo", + "tokenizer_2", + ) + text_encoder_cls, text_encoder_id, text_encoder_subfolder = ( + LlamaModel, + "hf-internal-testing/tiny-random-hunyuanvideo", + "text_encoder", + ) + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = ( + CLIPTextModel, + "hf-internal-testing/tiny-random-hunyuanvideo", + "text_encoder_2", + ) + + @property + def output_shape(self): + return (1, 9, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + sizes = (4, 4) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "num_frames": num_frames, + "num_inference_steps": 1, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "prompt_template": {"template": "{}", "crop_start": 0}, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) + def test_lora_fuse_nan(self): + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + + out = pipe( + prompt=inputs["prompt"], + height=inputs["height"], + width=inputs["width"], + num_frames=inputs["num_frames"], + num_inference_steps=inputs["num_inference_steps"], + max_sequence_length=inputs["max_sequence_length"], + output_type="np", + )[0] + + self.assertTrue(np.isnan(out).all()) + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + # TODO(aryan): Fix the following test + @unittest.skip("This test fails with an error I haven't been able to debug yet.") + def test_simple_inference_save_pretrained(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ac7a944cd026..73ed17049c1b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests: has_two_text_encoders = False has_three_text_encoders = False - text_encoder_cls, text_encoder_id = None, None - text_encoder_2_cls, text_encoder_2_id = None, None - text_encoder_3_cls, text_encoder_3_id = None, None - tokenizer_cls, tokenizer_id = None, None - tokenizer_2_cls, tokenizer_2_id = None, None - tokenizer_3_cls, tokenizer_3_id = None, None + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, None + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, None + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, None + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, None + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, None + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, None unet_kwargs = None transformer_cls = None @@ -124,16 +124,26 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): torch.manual_seed(0) vae = self.vae_cls(**self.vae_kwargs) - text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) - tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + text_encoder = self.text_encoder_cls.from_pretrained( + self.text_encoder_id, subfolder=self.text_encoder_subfolder + ) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) if self.text_encoder_2_cls is not None: - text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id) - tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id) + text_encoder_2 = self.text_encoder_2_cls.from_pretrained( + self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder + ) + tokenizer_2 = self.tokenizer_2_cls.from_pretrained( + self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder + ) if self.text_encoder_3_cls is not None: - text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id) - tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id) + text_encoder_3 = self.text_encoder_3_cls.from_pretrained( + self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder + ) + tokenizer_3 = self.tokenizer_3_cls.from_pretrained( + self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder + ) text_lora_config = LoraConfig( r=rank, From 9764f229d4a8386b4602711d0da5a4b02d9aa791 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 19 Dec 2024 22:20:40 +0530 Subject: [PATCH 07/13] [Single File] Add single file support for Mochi Transformer (#10268) update --- src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 109 ++++++++++++++++++ .../models/transformers/transformer_mochi.py | 3 +- 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 9641435fa5a6..d102282025c7 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -32,6 +32,7 @@ convert_ldm_vae_checkpoint, convert_ltx_transformer_checkpoint_to_diffusers, convert_ltx_vae_checkpoint_to_diffusers, + convert_mochi_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, @@ -96,6 +97,10 @@ "default_subfolder": "vae", }, "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers}, + "MochiTransformer3DModel": { + "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ded466b35e9a..8b2bf12214cd 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -106,6 +106,7 @@ ], "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", "autoencoder-dc-sana": "encoder.project_in.conv.bias", + "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -159,6 +160,7 @@ "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, + "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, } # Use to configure model sample size when original config is provided @@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "autoencoder-dc-f128c512" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): + model_type = "mochi-1-preview" + else: model_type = "v1" @@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim): return new_weight +def swap_proj_gate(weight): + proj, gate = weight.chunk(2, dim=0) + new_weight = torch.cat([gate, proj], dim=0) + return new_weight + + def get_attn2_layers(state_dict): attn2_layers = [] for key in state_dict.keys(): @@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict): handler_fn_inplace(key, converted_state_dict) return converted_state_dict + + +def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + new_state_dict = {} + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Convert time_embed + new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") + new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") + new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") + new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." + + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight") + new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + else: + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + + # Visual attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight") + new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight") + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight") + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias") + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_x.w1.weight") + ) + new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_y.w1.weight") + ) + new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight") + + # Output layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) + new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") + + return new_state_dict diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index fe72dc56883e..41e5289f2d57 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -20,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward @@ -304,7 +305,7 @@ def forward( @maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). From 3ee966950b636bcb9a78cc107da7887f195ac1a2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 19 Dec 2024 22:34:44 +0530 Subject: [PATCH 08/13] Allow Mochi Transformer to be split across multiple GPUs (#10300) update --- src/diffusers/models/transformers/transformer_mochi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 41e5289f2d57..8763ea450253 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -335,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri """ _supports_gradient_checkpointing = True + _no_split_modules = ["MochiTransformerBlock"] @register_to_config def __init__( From 074798b2997a6f1a329924b400a0db924e8e6735 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 17:04:57 +0000 Subject: [PATCH 09/13] Fix `local_files_only` for checkpoints with shards (#10294) --- src/diffusers/utils/hub_utils.py | 67 ++++++++++++++------------------ 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ef4715ee0e1e..a6dfe18433e3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -455,48 +455,39 @@ def _get_checkpoint_shard_files( allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] - if not local_files_only: - # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) - for shard_file in original_shard_filenames: - shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) - if not shard_file_present: - raise EnvironmentError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) - - try: - # Load from URL - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - if subfolder is not None: - cached_folder = os.path.join(cached_folder, subfolder) - - # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so - # we don't have to catch them here. We have also dealt with EntryNotFoundError. - except HTTPError as e: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" - " again after checking your internet connection." - ) from e + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) - # If `local_files_only=True`, `cached_folder` may not contain all the shard files. - elif local_files_only: - _check_if_shards_exist_locally( - local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, ) if subfolder is not None: - cached_folder = os.path.join(cache_dir, subfolder) + cached_folder = os.path.join(cached_folder, subfolder) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e return cached_folder, sharded_metadata From d8825e7697d2ac982046f96652261a60596c4944 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Dec 2024 02:35:41 +0530 Subject: [PATCH 10/13] Fix failing lora tests after HunyuanVideo lora (#10307) fix --- tests/lora/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 73ed17049c1b..0a0366fd8d2b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests: has_two_text_encoders = False has_three_text_encoders = False - text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, None - text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, None - text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, None - tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, None - tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, None - tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, None + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, "" + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, "" + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, "" + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" unet_kwargs = None transformer_cls = None From b756ec6e80b3d94c3ae7dc356bdbbdb426a05dca Mon Sep 17 00:00:00 2001 From: djm <92705171+Foundsheep@users.noreply.github.com> Date: Fri, 20 Dec 2024 07:24:18 +0900 Subject: [PATCH 11/13] unet's `sample_size` attribute is to accept tuple(h, w) in `StableDiffusionPipeline` (#10181) --- .../models/unets/unet_2d_condition.py | 2 +- .../pipeline_stable_diffusion.py | 21 ++++++++++++++++--- .../stable_diffusion/test_stable_diffusion.py | 8 +++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 4f55df32b738..e488f5897ebc 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -170,7 +170,7 @@ class conditioning with `class_embed_type` equal to `None`. @register_to_config def __init__( self, - sample_size: Optional[int] = None, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4fd6a43a955a..ac6c8253e432 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -255,7 +255,12 @@ def __init__( is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") + and self._is_unet_config_sample_size_int + and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -902,8 +907,18 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + if not height or not width: + height = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[0] + ) + width = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[1] + ) + height, width = height * self.vae_scale_factor, width * self.vae_scale_factor # to deal with lora scaling and other possible forward hooks # 1. Check inputs. Raise error if not correct diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index f37d598c8387..ccd5567106d2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -840,6 +840,14 @@ def callback_on_step_end(pipe, i, t, callback_kwargs): # they should be the same assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + def test_pipeline_accept_tuple_type_unet_sample_size(self): + # the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size + sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" + sample_size = [60, 80] + customised_unet = UNet2DConditionModel(sample_size=sample_size) + pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet) + assert pipe.unet.config.sample_size == sample_size + @slow @require_torch_gpu From 648d968cfc69074eaf51df3d337100f9805b030e Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:45:45 -0800 Subject: [PATCH 12/13] Enable Gradient Checkpointing for UNet2DModel (New) (#7201) * Port UNet2DModel gradient checkpointing code from #6718. --------- Co-authored-by: Sayak Paul Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com> Co-authored-by: Patrick von Platen Co-authored-by: Dhruv Nair Co-authored-by: hlky --- src/diffusers/models/unets/unet_2d.py | 6 ++ src/diffusers/models/unets/unet_2d_blocks.py | 83 +++++++++++++++++-- .../versatile_diffusion/modeling_text_unet.py | 29 ++++++- .../test_models_autoencoder_kl.py | 2 +- ..._models_autoencoder_kl_temporal_decoder.py | 2 +- tests/models/test_modeling_common.py | 4 +- tests/models/unets/test_models_unet_2d.py | 42 ++++++++++ 7 files changed, 154 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index d05af686dede..bec62ce5cf45 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): conditioning with `class_embed_type` equal to `None`. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -241,6 +243,10 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index b9d186ac1aa6..b4e0cea7c71d 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -731,12 +731,35 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states, temb=temb) - hidden_states = resnet(hidden_states, temb) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1116,6 +1139,8 @@ def __init__( else: self.downsamplers = None + self.gradient_checkpointing = False + def forward( self, hidden_states: torch.Tensor, @@ -1130,9 +1155,30 @@ def forward( output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, **cross_attention_kwargs) - output_states = output_states + (hidden_states,) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -2354,6 +2400,7 @@ def __init__( else: self.upsamplers = None + self.gradient_checkpointing = False self.resolution_idx = resolution_idx def forward( @@ -2375,8 +2422,28 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn(hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 107a5a45bfa2..0fd8875a88a1 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2223,12 +2223,35 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states, temb=temb) - hidden_states = resnet(hidden_states, temb) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 52bf5aba204b..c584bdcf56a2 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -146,7 +146,7 @@ def test_enable_disable_slicing(self): ) def test_gradient_checkpointing_is_applied(self): - expected_set = {"Decoder", "Encoder"} + expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_from_pretrained_hub(self): diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py index 4308cb64896e..cf80ff50443e 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py @@ -65,7 +65,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_gradient_checkpointing_is_applied(self): - expected_set = {"Encoder", "TemporalDecoder"} + expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) @unittest.skip("Test unsupported.") diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a7594f2ea13f..91a462d5878e 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -803,7 +803,7 @@ def test_enable_disable_gradient_checkpointing(self): self.assertFalse(model.is_gradient_checkpointing) @require_torch_accelerator_with_training - def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5): + def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}): if not self.model_class._supports_gradient_checkpointing: return # Skip test if model does not support gradient checkpointing @@ -850,6 +850,8 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_ for name, param in named_params.items(): if "post_quant_conv" in name: continue + if name in skip: + continue self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol)) @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.") diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 5f827f274224..ddf5f53511f7 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -105,6 +105,23 @@ def test_mid_block_attn_groups(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "AttnUpBlock2D", + "AttnDownBlock2D", + "UNetMidBlock2D", + "UpBlock2D", + "DownBlock2D", + } + + # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` + attention_head_dim = 8 + block_out_channels = (16, 32) + + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels + ) + class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -220,6 +237,17 @@ def test_output_pretrained(self): self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) + def test_gradient_checkpointing_is_applied(self): + expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} + + # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` + attention_head_dim = 32 + block_out_channels = (32, 64) + + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels + ) + class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -329,3 +357,17 @@ def test_output_pretrained_ve_large(self): def test_forward_with_norm_groups(self): # not required for this model pass + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "UNetMidBlock2D", + } + + block_out_channels = (32, 64, 64, 64) + + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, block_out_channels=block_out_channels + ) + + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) From 319124847216a57a6ae12b567689aa72b28f1c02 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:48:18 +0000 Subject: [PATCH 13/13] [WIP] SD3.5 IP-Adapter Pipeline Integration (#9987) * Added support for single IPAdapter on SD3.5 pipeline --------- Co-authored-by: hlky Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: YiYi Xu --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/attnprocessor.md | 2 + docs/source/en/api/loaders/ip_adapter.md | 6 + docs/source/en/api/loaders/transformer_sd3.md | 29 ++ .../stable_diffusion/stable_diffusion_3.md | 69 ++++- src/diffusers/loaders/__init__.py | 12 +- src/diffusers/loaders/ip_adapter.py | 251 +++++++++++++++++- src/diffusers/loaders/transformer_sd3.py | 89 +++++++ src/diffusers/models/attention_processor.py | 172 ++++++++++++ src/diffusers/models/embeddings.py | 181 +++++++++++++ .../models/transformers/transformer_sd3.py | 16 +- .../pipeline_stable_diffusion_3.py | 129 ++++++++- .../test_pipeline_stable_diffusion_3.py | 2 + 13 files changed, 935 insertions(+), 25 deletions(-) create mode 100644 docs/source/en/api/loaders/transformer_sd3.md create mode 100644 src/diffusers/loaders/transformer_sd3.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 27e9fe5e191b..6ac66db73026 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -238,6 +238,8 @@ title: Textual Inversion - local: api/loaders/unet title: UNet + - local: api/loaders/transformer_sd3 + title: SD3Transformer2D - local: api/loaders/peft title: PEFT title: Loaders diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index fee0d7e35764..8bdffc330567 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech [[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0 +[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0 + ## JointAttnProcessor2_0 [[autodoc]] models.attention_processor.JointAttnProcessor2_0 diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md index a10f30ef8e5b..946a8b1af875 100644 --- a/docs/source/en/api/loaders/ip_adapter.md +++ b/docs/source/en/api/loaders/ip_adapter.md @@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading] [[autodoc]] loaders.ip_adapter.IPAdapterMixin +## SD3IPAdapterMixin + +[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin + - all + - is_ip_adapter_active + ## IPAdapterMaskProcessor [[autodoc]] image_processor.IPAdapterMaskProcessor \ No newline at end of file diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md new file mode 100644 index 000000000000..4fc9603054b4 --- /dev/null +++ b/docs/source/en/api/loaders/transformer_sd3.md @@ -0,0 +1,29 @@ + + +# SD3Transformer2D + +This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. + +The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs. + + + +To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide. + + + +## SD3Transformer2DLoadersMixin + +[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin + - all + - _load_ip_adapter_weights \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 8170c5280d38..eb67964ab0bd 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -59,9 +59,76 @@ image.save("sd3_hello_world.png") - [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large) - [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo) +## Image Prompting with IP-Adapters + +An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need: + +- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder. +- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`. +- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. + +IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. + +```python +import torch +from PIL import Image + +from diffusers import StableDiffusion3Pipeline +from transformers import SiglipVisionModel, SiglipImageProcessor + +image_encoder_id = "google/siglip-so400m-patch14-384" +ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter" + +feature_extractor = SiglipImageProcessor.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +) +image_encoder = SiglipVisionModel.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +).to( "cuda") + +pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", + torch_dtype=torch.float16, + feature_extractor=feature_extractor, + image_encoder=image_encoder, +).to("cuda") + +pipe.load_ip_adapter(ip_adapter_id) +pipe.set_ip_adapter_scale(0.6) + +ref_img = Image.open("image.jpg").convert('RGB') + +image = pipe( + width=1024, + height=1024, + prompt="a cat", + negative_prompt="lowres, low quality, worst quality", + num_inference_steps=24, + guidance_scale=5.0, + ip_adapter_image=ref_img +).images[0] + +image.save("result.jpg") +``` + +
+ +
IP-Adapter examples with prompt "a cat"
+
+ + + + +Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work. + + + + ## Memory Optimisations for SD3 -SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. +SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. ### Running Inference with Model Offloading diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 6ea382d721de..c7ea0be55db2 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -56,6 +56,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): @@ -74,7 +75,10 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] - _import_structure["ip_adapter"] = ["IPAdapterMixin"] + _import_structure["ip_adapter"] = [ + "IPAdapterMixin", + "SD3IPAdapterMixin", + ] _import_structure["peft"] = ["PeftAdapterMixin"] @@ -82,11 +86,15 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): - from .ip_adapter import IPAdapterMixin + from .ip_adapter import ( + IPAdapterMixin, + SD3IPAdapterMixin, + ) from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ca460f948e6f..11ce4f1634d7 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,15 +33,18 @@ if is_transformers_available(): - from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel + +from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, + JointAttnProcessor2_0, + SD3IPAdapterJointAttnProcessor2_0, +) - from ..models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, - IPAdapterXFormersAttnProcessor, - ) logger = logging.get_logger(__name__) @@ -348,3 +351,235 @@ def unload_ip_adapter(self): else value.__class__() ) self.unet.set_attn_processor(attn_procs) + + +class SD3IPAdapterMixin: + """Mixin for handling StableDiffusion 3 IP Adapters.""" + + @property + def is_ip_adapter_active(self) -> bool: + """Checks if IP-Adapter is loaded and scale > 0. + + IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, + the image context is irrelevant. + + Returns: + `bool`: True when IP-Adapter is loaded and any layer has scale > 0. + """ + scales = [ + attn_proc.scale + for attn_proc in self.transformer.attn_processors.values() + if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0) + ] + + return len(scales) > 0 and any(scale > 0 for scale in scales) + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str = "ip-adapter.safetensors", + subfolder: Optional[str] = None, + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ) -> None: + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + weight_name (`str`, defaults to "ip-adapter.safetensors"): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # Load the main state dict first + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_folder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + if image_encoder_folder.count("/") == 0: + image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() + else: + image_encoder_subfolder = Path(image_encoder_folder).as_posix() + + # Commons args for loading image encoder and image processor + kwargs = { + "low_cpu_mem_usage": low_cpu_mem_usage, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + + self.register_modules( + feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + ) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # Load IP-Adapter into transformer + self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + + def set_ip_adapter_scale(self, scale: float) -> None: + """ + Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only + conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages + the model to produce more diverse images, but they may not be as aligned with the image prompt. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.set_ip_adapter_scale(0.6) + >>> ... + ``` + + Args: + scale (float): + IP-Adapter scale to be set. + + """ + for attn_processor in self.transformer.attn_processors.values(): + if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0): + attn_processor.scale = scale + + def unload_ip_adapter(self) -> None: + """ + Unloads the IP Adapter weights. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # Remove image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=None) + + # Remove feature extractor + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=None) + + # Remove image projection + self.transformer.image_proj = None + + # Restore original attention processors layers + attn_procs = { + name: ( + JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__() + ) + for name, value in self.transformer.attn_processors.items() + } + self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py new file mode 100644 index 000000000000..435d1da06ca1 --- /dev/null +++ b/src/diffusers/loaders/transformer_sd3.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 +from ..models.embeddings import IPAdapterTimeImageProjection +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta + + +class SD3Transformer2DLoadersMixin: + """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" + + def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None: + """Sets IP-Adapter attention processors, image projection, and loads state_dict. + + Args: + state_dict (`Dict`): + State dict with keys "ip_adapter", which contains parameters for attention processors, and + "image_proj", which contains parameters for image projection net. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # IP-Adapter cross attention parameters + hidden_size = self.config.attention_head_dim * self.config.num_attention_heads + ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads + timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] + + # Dict where key is transformer layer index, value is attention processor's state dict + # ip_adapter state dict keys example: "0.norm_ip.linear.weight" + layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} + for key, weights in state_dict["ip_adapter"].items(): + idx, name = key.split(".", maxsplit=1) + layer_state_dict[int(idx)][name] = weights + + # Create IP-Adapter attention processor + attn_procs = {} + for idx, name in enumerate(self.attn_processors.keys()): + attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ).to(self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) + else: + load_model_dict_into_meta( + attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype + ) + + self.set_attn_processor(attn_procs) + + # Image projetion parameters + embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0] + heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64 + num_queries = state_dict["image_proj"]["latents"].shape[1] + timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + + # Image projection + self.image_proj = IPAdapterTimeImageProjection( + embed_dim=embed_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + heads=heads, + num_queries=num_queries, + timestep_in_dim=timestep_in_dim, + ).to(device=self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + else: + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 05cbaa40e693..ed0dd4f71d27 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5243,6 +5243,177 @@ def __call__( return hidden_states +class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module): + """ + Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with + additional image-based information and timestep embeddings. + + Args: + hidden_size (`int`): + The number of hidden channels. + ip_hidden_states_dim (`int`): + The image feature dimension. + head_dim (`int`): + The number of head channels. + timesteps_emb_dim (`int`, defaults to 1280): + The number of input channels for timestep embedding. + scale (`float`, defaults to 0.5): + IP-Adapter scale. + """ + + def __init__( + self, + hidden_size: int, + ip_hidden_states_dim: int, + head_dim: int, + timesteps_emb_dim: int = 1280, + scale: float = 0.5, + ): + super().__init__() + + # To prevent circular import + from .normalization import AdaLayerNorm, RMSNorm + + self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1) + self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.norm_q = RMSNorm(head_dim, 1e-6) + self.norm_k = RMSNorm(head_dim, 1e-6) + self.norm_ip_k = RMSNorm(head_dim, 1e-6) + self.scale = scale + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + ip_hidden_states: torch.FloatTensor = None, + temb: torch.FloatTensor = None, + ) -> torch.FloatTensor: + """ + Perform the attention computation, integrating image features (if provided) and timestep embeddings. + + If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0. + + Args: + attn (`Attention`): + Attention instance. + hidden_states (`torch.FloatTensor`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + The encoder hidden states. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask. + ip_hidden_states (`torch.FloatTensor`, *optional*): + Image embeddings. + temb (`torch.FloatTensor`, *optional*): + Timestep embeddings. + + Returns: + `torch.FloatTensor`: Output hidden states. + """ + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + img_query = query + img_key = key + img_value = value + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP Adapter + if self.scale != 0 and ip_hidden_states is not None: + # Norm image features + norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) + + # To k and v + ip_key = self.to_k_ip(norm_ip_hidden_states) + ip_value = self.to_v_ip(norm_ip_hidden_states) + + # Reshape + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Norm + query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + ip_key = self.norm_ip_k(ip_key) + + # cat img + key = torch.cat([img_key, ip_key], dim=2) + value = torch.cat([img_value, ip_value], dim=2) + + ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + ip_hidden_states * self.scale + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class PAGIdentitySelfAttnProcessor2_0: r""" Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -5772,6 +5943,7 @@ def __call__( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, + SD3IPAdapterJointAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, LoRAAttnProcessor, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 69b3ee8466f4..f1b339e6180b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2396,6 +2396,187 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: return out +class IPAdapterTimeImageProjectionBlock(nn.Module): + """Block for IPAdapterTimeImageProjection. + + Args: + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + """ + + def __init__( + self, + hidden_dim: int = 1280, + dim_head: int = 64, + heads: int = 20, + ffn_ratio: int = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim) + self.attn = Attention( + query_dim=hidden_dim, + cross_attention_dim=hidden_dim, + dim_head=dim_head, + heads=heads, + bias=False, + out_bias=False, + ) + self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False) + + # AdaLayerNorm + self.adaln_silu = nn.SiLU() + self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim) + self.adaln_norm = nn.LayerNorm(hidden_dim) + + # Set attention scale and fuse KV + self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) + self.attn.fuse_projections() + self.attn.to_k = None + self.attn.to_v = None + + def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + latents (`torch.Tensor`): + Latent features. + timestep_emb (`torch.Tensor`): + Timestep embedding. + + Returns: + `torch.Tensor`: Output latent features. + """ + + # Shift and scale for AdaLayerNorm + emb = self.adaln_proj(self.adaln_silu(timestep_emb)) + shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) + + # Fused Attention + residual = latents + x = self.ln0(x) + latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] + + batch_size = latents.shape[0] + + query = self.attn.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.attn.heads + + query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + + weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + latents = weight @ value + + latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim) + latents = self.attn.to_out[0](latents) + latents = self.attn.to_out[1](latents) + latents = latents + residual + + ## FeedForward + residual = latents + latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + return self.ff(latents) + residual + + +# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +class IPAdapterTimeImageProjection(nn.Module): + """Resampler of SD3 IP-Adapter with timestep embedding. + + Args: + embed_dim (`int`, defaults to 1152): + The feature dimension. + output_dim (`int`, defaults to 2432): + The number of output channels. + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + depth (`int`, defaults to 4): + The number of blocks. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + num_queries (`int`, defaults to 64): + The number of queries. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + timestep_in_dim (`int`, defaults to 320): + The number of input channels for timestep embedding. + timestep_flip_sin_to_cos (`bool`, defaults to True): + Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False). + timestep_freq_shift (`int`, defaults to 0): + Controls the timestep delta between frequencies between dimensions. + """ + + def __init__( + self, + embed_dim: int = 1152, + output_dim: int = 2432, + hidden_dim: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 20, + num_queries: int = 64, + ffn_ratio: int = 4, + timestep_in_dim: int = 320, + timestep_flip_sin_to_cos: bool = True, + timestep_freq_shift: int = 0, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5) + self.proj_in = nn.Linear(embed_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + self.layers = nn.ModuleList( + [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) + self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu") + + def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + timestep (`torch.Tensor`): + Timestep in denoising process. + Returns: + `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb). + """ + timestep_emb = self.time_proj(timestep).to(dtype=x.dtype) + timestep_emb = self.time_embedding(timestep_emb) + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + x = x + timestep_emb[:, None] + + for block in self.layers: + latents = block(x, latents, timestep_emb) + + latents = self.proj_out(latents) + latents = self.norm_out(latents) + + return latents, timestep_emb + + class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 79c4069e9a37..415540ef7f6a 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, @@ -103,7 +103,9 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor): return hidden_states -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SD3Transformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin +): """ The Transformer model introduced in Stable Diffusion 3. @@ -349,8 +351,8 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states (`list` of `torch.Tensor`): @@ -390,6 +392,12 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep) + + joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb) + for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers is_skip = True if skip_layers is not None and index_block in skip_layers else False diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 0a51dcbc1261..a53d786798ca 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,14 +17,16 @@ import torch from transformers import ( + BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + PreTrainedModel, T5EncoderModel, T5TokenizerFast, ) -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -142,7 +144,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): r""" Args: transformer ([`SD3Transformer2DModel`]): @@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`PreTrainedModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`BaseImageProcessor`, *optional*): + Image processor for IP Adapter. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" - _optional_components = [] + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( @@ -191,6 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, + image_encoder: PreTrainedModel = None, + feature_extractor: BaseImageProcessor = None, ): super().__init__() @@ -204,6 +212,8 @@ def __init__( tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -683,6 +693,83 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -705,6 +792,8 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -713,9 +802,9 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, skip_guidance_layers: List[int] = None, - skip_layer_guidance_scale: int = 2.8, - skip_layer_guidance_stop: int = 0.2, - skip_layer_guidance_start: int = 0.01, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_start: float = 0.01, mu: Optional[float] = None, ): r""" @@ -781,6 +870,11 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -938,7 +1032,22 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 6. Denoising loop + # 6. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 07ce5487f256..a6f718ae4fbb 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -103,6 +103,8 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0):