diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 27e9fe5e191b..6ac66db73026 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -238,6 +238,8 @@
title: Textual Inversion
- local: api/loaders/unet
title: UNet
+ - local: api/loaders/transformer_sd3
+ title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
title: Loaders
diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md
index fee0d7e35764..8bdffc330567 100644
--- a/docs/source/en/api/attnprocessor.md
+++ b/docs/source/en/api/attnprocessor.md
@@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
+[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
+
## JointAttnProcessor2_0
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md
index a10f30ef8e5b..946a8b1af875 100644
--- a/docs/source/en/api/loaders/ip_adapter.md
+++ b/docs/source/en/api/loaders/ip_adapter.md
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
+## SD3IPAdapterMixin
+
+[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
+ - all
+ - is_ip_adapter_active
+
## IPAdapterMaskProcessor
[[autodoc]] image_processor.IPAdapterMaskProcessor
\ No newline at end of file
diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md
new file mode 100644
index 000000000000..4fc9603054b4
--- /dev/null
+++ b/docs/source/en/api/loaders/transformer_sd3.md
@@ -0,0 +1,29 @@
+
+
+# SD3Transformer2D
+
+This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
+
+The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
+
+
+
+To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
+
+
+
+## SD3Transformer2DLoadersMixin
+
+[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
+ - all
+ - _load_ip_adapter_weights
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
index 8170c5280d38..eb67964ab0bd 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
+## Image Prompting with IP-Adapters
+
+An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
+
+- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
+- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
+- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
+
+IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
+
+```python
+import torch
+from PIL import Image
+
+from diffusers import StableDiffusion3Pipeline
+from transformers import SiglipVisionModel, SiglipImageProcessor
+
+image_encoder_id = "google/siglip-so400m-patch14-384"
+ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
+
+feature_extractor = SiglipImageProcessor.from_pretrained(
+ image_encoder_id,
+ torch_dtype=torch.float16
+)
+image_encoder = SiglipVisionModel.from_pretrained(
+ image_encoder_id,
+ torch_dtype=torch.float16
+).to( "cuda")
+
+pipe = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large",
+ torch_dtype=torch.float16,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+).to("cuda")
+
+pipe.load_ip_adapter(ip_adapter_id)
+pipe.set_ip_adapter_scale(0.6)
+
+ref_img = Image.open("image.jpg").convert('RGB')
+
+image = pipe(
+ width=1024,
+ height=1024,
+ prompt="a cat",
+ negative_prompt="lowres, low quality, worst quality",
+ num_inference_steps=24,
+ guidance_scale=5.0,
+ ip_adapter_image=ref_img
+).images[0]
+
+image.save("result.jpg")
+```
+
+
+
+
IP-Adapter examples with prompt "a cat"
+
+
+
+
+
+Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
+
+
+
+
## Memory Optimisations for SD3
-SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
+SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
### Running Inference with Model Offloading
diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md
index fe861d62472b..d82529c64de8 100644
--- a/examples/dreambooth/README_sana.md
+++ b/examples/dreambooth/README_sana.md
@@ -73,7 +73,7 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face
Now, we can launch training using:
```bash
-export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
+export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
@@ -124,4 +124,4 @@ We provide several options for optimizing memory optimization:
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
-Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
\ No newline at end of file
+Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index b59150376599..c7ea0be55db2 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -56,6 +56,7 @@ def text_encoder_attn_modules(text_encoder):
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
+ _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
@@ -70,10 +71,14 @@ def text_encoder_attn_modules(text_encoder):
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
"Mochi1LoraLoaderMixin",
+ "HunyuanVideoLoraLoaderMixin",
"SanaLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
- _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+ _import_structure["ip_adapter"] = [
+ "IPAdapterMixin",
+ "SD3IPAdapterMixin",
+ ]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -81,15 +86,20 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
+ from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
- from .ip_adapter import IPAdapterMixin
+ from .ip_adapter import (
+ IPAdapterMixin,
+ SD3IPAdapterMixin,
+ )
from .lora_pipeline import (
AmusedLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
FluxLoraLoaderMixin,
+ HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Mochi1LoraLoaderMixin,
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index ca460f948e6f..11ce4f1634d7 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -33,15 +33,18 @@
if is_transformers_available():
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
+
+from ..models.attention_processor import (
+ AttnProcessor,
+ AttnProcessor2_0,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
+ JointAttnProcessor2_0,
+ SD3IPAdapterJointAttnProcessor2_0,
+)
- from ..models.attention_processor import (
- AttnProcessor,
- AttnProcessor2_0,
- IPAdapterAttnProcessor,
- IPAdapterAttnProcessor2_0,
- IPAdapterXFormersAttnProcessor,
- )
logger = logging.get_logger(__name__)
@@ -348,3 +351,235 @@ def unload_ip_adapter(self):
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
+
+
+class SD3IPAdapterMixin:
+ """Mixin for handling StableDiffusion 3 IP Adapters."""
+
+ @property
+ def is_ip_adapter_active(self) -> bool:
+ """Checks if IP-Adapter is loaded and scale > 0.
+
+ IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
+ the image context is irrelevant.
+
+ Returns:
+ `bool`: True when IP-Adapter is loaded and any layer has scale > 0.
+ """
+ scales = [
+ attn_proc.scale
+ for attn_proc in self.transformer.attn_processors.values()
+ if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
+ ]
+
+ return len(scales) > 0 and any(scale > 0 for scale in scales)
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ weight_name: str = "ip-adapter.safetensors",
+ subfolder: Optional[str] = None,
+ image_encoder_folder: Optional[str] = "image_encoder",
+ **kwargs,
+ ) -> None:
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ weight_name (`str`, defaults to "ip-adapter.safetensors"):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `subfolder`.
+ subfolder (`str`, *optional*):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+ `image_encoder_folder="different_subfolder/image_encoder"`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+ # Load the main state dict first
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_folder is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+ if image_encoder_folder.count("/") == 0:
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+ else:
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+ # Commons args for loading image encoder and image processor
+ kwargs = {
+ "low_cpu_mem_usage": low_cpu_mem_usage,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ }
+
+ self.register_modules(
+ feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
+ self.device, dtype=self.dtype
+ ),
+ image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
+ self.device, dtype=self.dtype
+ ),
+ )
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+ )
+
+ # Load IP-Adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: float) -> None:
+ """
+ Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
+ conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
+ the model to produce more diverse images, but they may not be as aligned with the image prompt.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.set_ip_adapter_scale(0.6)
+ >>> ...
+ ```
+
+ Args:
+ scale (float):
+ IP-Adapter scale to be set.
+
+ """
+ for attn_processor in self.transformer.attn_processors.values():
+ if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
+ attn_processor.scale = scale
+
+ def unload_ip_adapter(self) -> None:
+ """
+ Unloads the IP Adapter weights.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+ # Remove image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=None)
+
+ # Remove feature extractor
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=None)
+
+ # Remove image projection
+ self.transformer.image_proj = None
+
+ # Restore original attention processors layers
+ attn_procs = {
+ name: (
+ JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
+ )
+ for name, value in self.transformer.attn_processors.items()
+ }
+ self.transformer.set_attn_processor(attn_procs)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index aab87b8f4dba..07c2c2272422 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -643,7 +643,11 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
old_state_dict,
new_state_dict,
old_key,
- [f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
+ [
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+ ],
)
if "down" in old_key:
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index b8c44e480093..46d744233014 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
super().unfuse_lora(components=components)
+class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`HunyuanVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index a791a250af08..9c00012ebc65 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -53,6 +53,7 @@
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
+ "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
}
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 9641435fa5a6..d102282025c7 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -32,6 +32,7 @@
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
+ convert_mochi_transformer_checkpoint_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
@@ -96,6 +97,10 @@
"default_subfolder": "vae",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
+ "MochiTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
}
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index ded466b35e9a..8b2bf12214cd 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -106,6 +106,7 @@
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
+ "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -159,6 +160,7 @@
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+ "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
}
# Use to configure model sample size when original config is provided
@@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "autoencoder-dc-f128c512"
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+ model_type = "mochi-1-preview"
+
else:
model_type = "v1"
@@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim):
return new_weight
+def swap_proj_gate(weight):
+ proj, gate = weight.chunk(2, dim=0)
+ new_weight = torch.cat([gate, proj], dim=0)
+ return new_weight
+
+
def get_attn2_layers(state_dict):
attn2_layers = []
for key in state_dict.keys():
@@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ new_state_dict = {}
+
+ # Comfy checkpoints add this prefix
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ # Convert patch_embed
+ new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Convert time_embed
+ new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+ new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+ new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+ new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+ new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+ new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+ new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+ new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+ new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+ new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+ new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+ # Convert transformer blocks
+ num_layers = 48
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"blocks.{i}."
+
+ # norm1
+ new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+ new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+ if i < num_layers - 1:
+ new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+ new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+ else:
+ new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+ old_prefix + "mod_y.weight"
+ )
+ new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+ # Visual attention
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+ new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+ # Context attention
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+ new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+ new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+ new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+ old_prefix + "attn.q_norm_y.weight"
+ )
+ new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+ old_prefix + "attn.k_norm_y.weight"
+ )
+ if i < num_layers - 1:
+ new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+ old_prefix + "attn.proj_y.weight"
+ )
+ new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+ # MLP
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+ checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+ )
+ new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+ if i < num_layers - 1:
+ new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+ checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+ )
+ new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+ # Output layers
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+ new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+ new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+ return new_state_dict
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
new file mode 100644
index 000000000000..435d1da06ca1
--- /dev/null
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -0,0 +1,89 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+
+from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
+from ..models.embeddings import IPAdapterTimeImageProjection
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+
+
+class SD3Transformer2DLoadersMixin:
+ """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
+
+ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+ """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+ Args:
+ state_dict (`Dict`):
+ State dict with keys "ip_adapter", which contains parameters for attention processors, and
+ "image_proj", which contains parameters for image projection net.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+ # IP-Adapter cross attention parameters
+ hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+ ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+ timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
+
+ # Dict where key is transformer layer index, value is attention processor's state dict
+ # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+ layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+ for key, weights in state_dict["ip_adapter"].items():
+ idx, name = key.split(".", maxsplit=1)
+ layer_state_dict[int(idx)][name] = weights
+
+ # Create IP-Adapter attention processor
+ attn_procs = {}
+ for idx, name in enumerate(self.attn_processors.keys()):
+ attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size,
+ ip_hidden_states_dim=ip_hidden_states_dim,
+ head_dim=self.config.attention_head_dim,
+ timesteps_emb_dim=timesteps_emb_dim,
+ ).to(self.device, dtype=self.dtype)
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+ else:
+ load_model_dict_into_meta(
+ attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
+ )
+
+ self.set_attn_processor(attn_procs)
+
+ # Image projetion parameters
+ embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
+ output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
+ hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
+ heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
+ num_queries = state_dict["image_proj"]["latents"].shape[1]
+ timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+
+ # Image projection
+ self.image_proj = IPAdapterTimeImageProjection(
+ embed_dim=embed_dim,
+ output_dim=output_dim,
+ hidden_dim=hidden_dim,
+ heads=heads,
+ num_queries=num_queries,
+ timestep_in_dim=timestep_in_dim,
+ ).to(device=self.device, dtype=self.dtype)
+
+ if not low_cpu_mem_usage:
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+ else:
+ load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 05cbaa40e693..ed0dd4f71d27 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -5243,6 +5243,177 @@ def __call__(
return hidden_states
+class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+ """
+ Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
+ additional image-based information and timestep embeddings.
+
+ Args:
+ hidden_size (`int`):
+ The number of hidden channels.
+ ip_hidden_states_dim (`int`):
+ The image feature dimension.
+ head_dim (`int`):
+ The number of head channels.
+ timesteps_emb_dim (`int`, defaults to 1280):
+ The number of input channels for timestep embedding.
+ scale (`float`, defaults to 0.5):
+ IP-Adapter scale.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ ip_hidden_states_dim: int,
+ head_dim: int,
+ timesteps_emb_dim: int = 1280,
+ scale: float = 0.5,
+ ):
+ super().__init__()
+
+ # To prevent circular import
+ from .normalization import AdaLayerNorm, RMSNorm
+
+ self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+ self.norm_q = RMSNorm(head_dim, 1e-6)
+ self.norm_k = RMSNorm(head_dim, 1e-6)
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
+ self.scale = scale
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ ip_hidden_states: torch.FloatTensor = None,
+ temb: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ """
+ Perform the attention computation, integrating image features (if provided) and timestep embeddings.
+
+ If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
+
+ Args:
+ attn (`Attention`):
+ Attention instance.
+ hidden_states (`torch.FloatTensor`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor`, *optional*):
+ The encoder hidden states.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Attention mask.
+ ip_hidden_states (`torch.FloatTensor`, *optional*):
+ Image embeddings.
+ temb (`torch.FloatTensor`, *optional*):
+ Timestep embeddings.
+
+ Returns:
+ `torch.FloatTensor`: Output hidden states.
+ """
+ residual = hidden_states
+
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ img_query = query
+ img_key = key
+ img_value = value
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # IP Adapter
+ if self.scale != 0 and ip_hidden_states is not None:
+ # Norm image features
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
+
+ # To k and v
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
+ ip_value = self.to_v_ip(norm_ip_hidden_states)
+
+ # Reshape
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # Norm
+ query = self.norm_q(img_query)
+ img_key = self.norm_k(img_key)
+ ip_key = self.norm_ip_k(ip_key)
+
+ # cat img
+ key = torch.cat([img_key, ip_key], dim=2)
+ value = torch.cat([img_value, ip_value], dim=2)
+
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + ip_hidden_states * self.scale
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -5772,6 +5943,7 @@ def __call__(
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
+ SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
LoRAAttnProcessor,
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
index bded90a8bcff..5c1d94d4e18f 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -792,12 +792,12 @@ def __init__(
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
- self.tile_sample_min_num_frames = 64
+ self.tile_sample_min_num_frames = 16
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
- self.tile_sample_stride_num_frames = 48
+ self.tile_sample_stride_num_frames = 12
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
@@ -1003,7 +1003,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
- tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
@@ -1020,7 +1020,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
- result_rows.append(torch.cat(result_row, dim=-1))
+ result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 69b3ee8466f4..f1b339e6180b 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -2396,6 +2396,187 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
return out
+class IPAdapterTimeImageProjectionBlock(nn.Module):
+ """Block for IPAdapterTimeImageProjection.
+
+ Args:
+ hidden_dim (`int`, defaults to 1280):
+ The number of hidden channels.
+ dim_head (`int`, defaults to 64):
+ The number of head channels.
+ heads (`int`, defaults to 20):
+ Parallel attention heads.
+ ffn_ratio (`int`, defaults to 4):
+ The expansion ratio of feedforward network hidden layer channels.
+ """
+
+ def __init__(
+ self,
+ hidden_dim: int = 1280,
+ dim_head: int = 64,
+ heads: int = 20,
+ ffn_ratio: int = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(hidden_dim)
+ self.ln1 = nn.LayerNorm(hidden_dim)
+ self.attn = Attention(
+ query_dim=hidden_dim,
+ cross_attention_dim=hidden_dim,
+ dim_head=dim_head,
+ heads=heads,
+ bias=False,
+ out_bias=False,
+ )
+ self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+ # AdaLayerNorm
+ self.adaln_silu = nn.SiLU()
+ self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+ self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+ # Set attention scale and fuse KV
+ self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+ self.attn.fuse_projections()
+ self.attn.to_k = None
+ self.attn.to_v = None
+
+ def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ x (`torch.Tensor`):
+ Image features.
+ latents (`torch.Tensor`):
+ Latent features.
+ timestep_emb (`torch.Tensor`):
+ Timestep embedding.
+
+ Returns:
+ `torch.Tensor`: Output latent features.
+ """
+
+ # Shift and scale for AdaLayerNorm
+ emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+ shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+ # Fused Attention
+ residual = latents
+ x = self.ln0(x)
+ latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+ batch_size = latents.shape[0]
+
+ query = self.attn.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.attn.heads
+
+ query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+ weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ latents = weight @ value
+
+ latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+ latents = self.attn.to_out[0](latents)
+ latents = self.attn.to_out[1](latents)
+ latents = latents + residual
+
+ ## FeedForward
+ residual = latents
+ latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ return self.ff(latents) + residual
+
+
+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class IPAdapterTimeImageProjection(nn.Module):
+ """Resampler of SD3 IP-Adapter with timestep embedding.
+
+ Args:
+ embed_dim (`int`, defaults to 1152):
+ The feature dimension.
+ output_dim (`int`, defaults to 2432):
+ The number of output channels.
+ hidden_dim (`int`, defaults to 1280):
+ The number of hidden channels.
+ depth (`int`, defaults to 4):
+ The number of blocks.
+ dim_head (`int`, defaults to 64):
+ The number of head channels.
+ heads (`int`, defaults to 20):
+ Parallel attention heads.
+ num_queries (`int`, defaults to 64):
+ The number of queries.
+ ffn_ratio (`int`, defaults to 4):
+ The expansion ratio of feedforward network hidden layer channels.
+ timestep_in_dim (`int`, defaults to 320):
+ The number of input channels for timestep embedding.
+ timestep_flip_sin_to_cos (`bool`, defaults to True):
+ Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+ timestep_freq_shift (`int`, defaults to 0):
+ Controls the timestep delta between frequencies between dimensions.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 1152,
+ output_dim: int = 2432,
+ hidden_dim: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 20,
+ num_queries: int = 64,
+ ffn_ratio: int = 4,
+ timestep_in_dim: int = 320,
+ timestep_flip_sin_to_cos: bool = True,
+ timestep_freq_shift: int = 0,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
+ self.proj_in = nn.Linear(embed_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+ self.layers = nn.ModuleList(
+ [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ x (`torch.Tensor`):
+ Image features.
+ timestep (`torch.Tensor`):
+ Timestep in denoising process.
+ Returns:
+ `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+ """
+ timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+ timestep_emb = self.time_embedding(timestep_emb)
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+ x = x + timestep_emb[:, None]
+
+ for block in self.layers:
+ latents = block(x, latents, timestep_emb)
+
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents)
+
+ return latents, timestep_emb
+
+
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index d8f9834ea61c..089389b5f9ad 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -19,7 +19,8 @@
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..embeddings import (
@@ -32,6 +33,9 @@
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
class HunyuanVideoAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -496,7 +500,47 @@ def forward(
return hidden_states, encoder_hidden_states
-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ r"""
+ A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_single_layers (`int`, defaults to `40`):
+ The number of layers of single-stream blocks to use.
+ num_refiner_layers (`int`, defaults to `2`):
+ The number of layers of refiner blocks to use.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the tmeporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ guidance_embeds (`bool`, defaults to `True`):
+ Whether to use guidance embeddings in the model.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ pooled_projection_dim (`int`, defaults to `768`):
+ The dimension of the pooled projection of the text embeddings.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
@register_to_config
def __init__(
self,
@@ -630,8 +674,24 @@ def forward(
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -717,6 +777,10 @@ def custom_forward(*inputs):
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
if not return_dict:
return (hidden_states,)
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index fe72dc56883e..8763ea450253 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -20,6 +20,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
@@ -304,7 +305,7 @@ def forward(
@maybe_allow_in_graph
-class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
@@ -334,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
+ _no_split_modules = ["MochiTransformerBlock"]
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 79c4069e9a37..415540ef7f6a 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -18,7 +18,7 @@
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
@@ -103,7 +103,9 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
return hidden_states
-class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class SD3Transformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+):
"""
The Transformer model introduced in Stable Diffusion 3.
@@ -349,8 +351,8 @@ def forward(
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+ Embeddings projected from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states (`list` of `torch.Tensor`):
@@ -390,6 +392,12 @@ def forward(
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
+
+ joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
+
for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
is_skip = True if skip_layers is not None and index_block in skip_layers else False
diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index d05af686dede..bec62ce5cf45 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
conditioning with `class_embed_type` equal to `None`.
"""
+ _supports_gradient_checkpointing = True
+
@register_to_config
def __init__(
self,
@@ -241,6 +243,10 @@ def __init__(
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
def forward(
self,
sample: torch.Tensor,
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index b9d186ac1aa6..b4e0cea7c71d 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -731,12 +731,35 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
+ self.gradient_checkpointing = False
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if attn is not None:
- hidden_states = attn(hidden_states, temb=temb)
- hidden_states = resnet(hidden_states, temb)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -1116,6 +1139,8 @@ def __init__(
else:
self.downsamplers = None
+ self.gradient_checkpointing = False
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ def forward(
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
- output_states = output_states + (hidden_states,)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ def __init__(
else:
self.upsamplers = None
+ self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
@@ -2375,8 +2422,28 @@ def forward(
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 4f55df32b738..e488f5897ebc 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -170,7 +170,7 @@ class conditioning with `class_embed_type` equal to `None`.
@register_to_config
def __init__(
self,
- sample_size: Optional[int] = None,
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index a0f95fe6cdc1..f3a05c2c661f 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -35,9 +35,12 @@
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import (
+ FluxControlImg2ImgPipeline,
+ FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
+ FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
@@ -125,6 +128,7 @@
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
+ ("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
@@ -150,6 +154,7 @@
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
+ ("flux-control", FluxControlImg2ImgPipeline),
]
)
@@ -168,6 +173,7 @@
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
+ ("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
]
)
@@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
+ if "ControlPipeline" in orig_class_name:
+ to_replace = "ControlPipeline"
+ else:
+ to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
- orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
else:
- orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
- orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
+ orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
- to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
+ if "Img2Img" in orig_class_name:
+ to_replace = "Img2ImgPipeline"
+ elif "ControlPipeline" in orig_class_name:
+ to_replace = "ControlPipeline"
+ else:
+ to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+ if to_replace == "ControlPipeline":
+ orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
+
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
- to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
+ if "Inpaint" in orig_class_name:
+ to_replace = "InpaintPipeline"
+ elif "ControlPipeline" in orig_class_name:
+ to_replace = "ControlPipeline"
+ else:
+ to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+ if to_replace == "ControlPipeline":
+ orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 107a5a45bfa2..0fd8875a88a1 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -2223,12 +2223,35 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
+ self.gradient_checkpointing = False
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if attn is not None:
- hidden_states = attn(hidden_states, temb=temb)
- hidden_states = resnet(hidden_states, temb)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
return hidden_states
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index bd3d3c1e8485..4423ccf97932 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -20,6 +20,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
@@ -132,7 +133,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class HunyuanVideoPipeline(DiffusionPipeline):
+class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
@@ -447,6 +448,10 @@ def guidance_scale(self):
def num_timesteps(self):
return self._num_timesteps
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
@property
def interrupt(self):
return self._interrupt
@@ -471,6 +476,7 @@ def __call__(
prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -525,6 +531,10 @@ def __call__(
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -562,6 +572,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
self._interrupt = False
device = self._execution_device
@@ -640,6 +651,7 @@ def __call__(
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
+ attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index a504184ea2f2..c505c5a262a3 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import fnmatch
import importlib
import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ expected_types = pipeline_class._get_signature_types()
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -833,6 +835,26 @@ def load_module(name, value):
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+ for key in init_dict.keys():
+ if key not in passed_class_obj:
+ continue
+ if "scheduler" in key:
+ continue
+
+ class_obj = passed_class_obj[key]
+ _expected_class_types = []
+ for expected_type in expected_types[key]:
+ if isinstance(expected_type, enum.EnumMeta):
+ _expected_class_types.extend(expected_type.__members__.keys())
+ else:
+ _expected_class_types.append(expected_type.__name__)
+
+ _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
+ if not _is_valid_type:
+ logger.warning(
+ f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
+ )
+
# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
raise NotImplementedError(
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 4fd6a43a955a..ac6c8253e432 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -255,7 +255,12 @@ def __init__(
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
+ is_unet_sample_size_less_64 = (
+ hasattr(unet.config, "sample_size")
+ and self._is_unet_config_sample_size_int
+ and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -902,8 +907,18 @@ def __call__(
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
- height = height or self.unet.config.sample_size * self.vae_scale_factor
- width = width or self.unet.config.sample_size * self.vae_scale_factor
+ if not height or not width:
+ height = (
+ self.unet.config.sample_size
+ if self._is_unet_config_sample_size_int
+ else self.unet.config.sample_size[0]
+ )
+ width = (
+ self.unet.config.sample_size
+ if self._is_unet_config_sample_size_int
+ else self.unet.config.sample_size[1]
+ )
+ height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 0a51dcbc1261..a53d786798ca 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,14 +17,16 @@
import torch
from transformers import (
+ BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
+ PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,7 +144,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ image_encoder (`PreTrainedModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`BaseImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -191,6 +197,8 @@ def __init__(
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
+ image_encoder: PreTrainedModel = None,
+ feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -204,6 +212,8 @@ def __init__(
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -683,6 +693,83 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -705,6 +792,8 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -713,9 +802,9 @@ def __call__(
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: List[int] = None,
- skip_layer_guidance_scale: int = 2.8,
- skip_layer_guidance_stop: int = 0.2,
- skip_layer_guidance_start: int = 0.01,
+ skip_layer_guidance_scale: float = 2.8,
+ skip_layer_guidance_stop: float = 0.2,
+ skip_layer_guidance_start: float = 0.01,
mu: Optional[float] = None,
):
r"""
@@ -781,6 +870,11 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -938,7 +1032,22 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
- # 6. Denoising loop
+ # 6. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index ef4715ee0e1e..a6dfe18433e3 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"]
- if not local_files_only:
- # `model_info` call must guarded with the above condition.
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
- for shard_file in original_shard_filenames:
- shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
- if not shard_file_present:
- raise EnvironmentError(
- f"{shards_path} does not appear to have a file named {shard_file} which is "
- "required according to the checkpoint index."
- )
-
- try:
- # Load from URL
- cached_folder = snapshot_download(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- allow_patterns=allow_patterns,
- ignore_patterns=ignore_patterns,
- user_agent=user_agent,
- )
- if subfolder is not None:
- cached_folder = os.path.join(cached_folder, subfolder)
-
- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
- # we don't have to catch them here. We have also dealt with EntryNotFoundError.
- except HTTPError as e:
+ # `model_info` call must guarded with the above condition.
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+ for shard_file in original_shard_filenames:
+ shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+ if not shard_file_present:
raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
- " again after checking your internet connection."
- ) from e
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
+ "required according to the checkpoint index."
+ )
- # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
- elif local_files_only:
- _check_if_shards_exist_locally(
- local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
+ try:
+ # Load from URL
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
)
if subfolder is not None:
- cached_folder = os.path.join(cache_dir, subfolder)
+ cached_folder = os.path.join(cached_folder, subfolder)
+
+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+ # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+ except HTTPError as e:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+ " again after checking your internet connection."
+ ) from e
return cached_folder, sharded_metadata
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
new file mode 100644
index 000000000000..59464c052684
--- /dev/null
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -0,0 +1,228 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ is_torch_version,
+ require_peft_backend,
+ skip_mps,
+ torch_device,
+)
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = HunyuanVideoPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "guidance_embeds": True,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ }
+ transformer_cls = HunyuanVideoTransformer3DModel
+ vae_kwargs = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 4,
+ "down_block_types": (
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ "up_block_types": (
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ "block_out_channels": (8, 8, 8, 8),
+ "layers_per_block": 1,
+ "act_fn": "silu",
+ "norm_num_groups": 4,
+ "scaling_factor": 0.476986,
+ "spatial_compression_ratio": 8,
+ "temporal_compression_ratio": 4,
+ "mid_block_add_attention": True,
+ }
+ vae_cls = AutoencoderKLHunyuanVideo
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id, tokenizer_subfolder = (
+ LlamaTokenizerFast,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "tokenizer",
+ )
+ tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = (
+ CLIPTokenizer,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "tokenizer_2",
+ )
+ text_encoder_cls, text_encoder_id, text_encoder_subfolder = (
+ LlamaModel,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "text_encoder",
+ )
+ text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = (
+ CLIPTextModel,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "text_encoder_2",
+ )
+
+ @property
+ def output_shape(self):
+ return (1, 9, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "prompt_template": {"template": "{}", "crop_start": 0},
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @pytest.mark.xfail(
+ condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
+ reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
+ strict=True,
+ )
+ def test_lora_fuse_nan(self):
+ for scheduler_cls in self.scheduler_classes:
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+
+ out = pipe(
+ prompt=inputs["prompt"],
+ height=inputs["height"],
+ width=inputs["width"],
+ num_frames=inputs["num_frames"],
+ num_inference_steps=inputs["num_inference_steps"],
+ max_sequence_length=inputs["max_sequence_length"],
+ output_type="np",
+ )[0]
+
+ self.assertTrue(np.isnan(out).all())
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ # TODO(aryan): Fix the following test
+ @unittest.skip("This test fails with an error I haven't been able to debug yet.")
+ def test_simple_inference_save_pretrained(self):
+ pass
+
+ @unittest.skip("Not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in HunyuanVideo.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index ac7a944cd026..0a0366fd8d2b 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests:
has_two_text_encoders = False
has_three_text_encoders = False
- text_encoder_cls, text_encoder_id = None, None
- text_encoder_2_cls, text_encoder_2_id = None, None
- text_encoder_3_cls, text_encoder_3_id = None, None
- tokenizer_cls, tokenizer_id = None, None
- tokenizer_2_cls, tokenizer_2_id = None, None
- tokenizer_3_cls, tokenizer_3_id = None, None
+ text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
+ text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
+ text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
+ tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
+ tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
+ tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
unet_kwargs = None
transformer_cls = None
@@ -124,16 +124,26 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
torch.manual_seed(0)
vae = self.vae_cls(**self.vae_kwargs)
- text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
- tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
+ text_encoder = self.text_encoder_cls.from_pretrained(
+ self.text_encoder_id, subfolder=self.text_encoder_subfolder
+ )
+ tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
if self.text_encoder_2_cls is not None:
- text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
- tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)
+ text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
+ self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
+ )
+ tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
+ self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
+ )
if self.text_encoder_3_cls is not None:
- text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
- tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)
+ text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
+ self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
+ )
+ tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
+ self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
+ )
text_lora_config = LoraConfig(
r=rank,
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
index 826ac30d5f2f..7b7901a6fd94 100644
--- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -43,10 +43,14 @@ def get_autoencoder_kl_hunyuan_video_config(self):
"down_block_types": (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
),
"up_block_types": (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
),
"block_out_channels": (8, 8, 8, 8),
"layers_per_block": 1,
@@ -154,6 +158,27 @@ def test_gradient_checkpointing_is_applied(self):
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+ # We need to overwrite this test because the base test does not account length of down_block_types
+ def test_forward_with_norm_groups(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 16, 16, 16)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
index 52bf5aba204b..c584bdcf56a2 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -146,7 +146,7 @@ def test_enable_disable_slicing(self):
)
def test_gradient_checkpointing_is_applied(self):
- expected_set = {"Decoder", "Encoder"}
+ expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_from_pretrained_hub(self):
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
index 4308cb64896e..cf80ff50443e 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -65,7 +65,7 @@ def prepare_init_args_and_inputs_for_common(self):
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
- expected_set = {"Encoder", "TemporalDecoder"}
+ expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test unsupported.")
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index a7594f2ea13f..91a462d5878e 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -803,7 +803,7 @@ def test_enable_disable_gradient_checkpointing(self):
self.assertFalse(model.is_gradient_checkpointing)
@require_torch_accelerator_with_training
- def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
@@ -850,6 +850,8 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
+ if name in skip:
+ continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index 5f827f274224..ddf5f53511f7 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -105,6 +105,23 @@ def test_mid_block_attn_groups(self):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "AttnUpBlock2D",
+ "AttnDownBlock2D",
+ "UNetMidBlock2D",
+ "UpBlock2D",
+ "DownBlock2D",
+ }
+
+ # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+ attention_head_dim = 8
+ block_out_channels = (16, 32)
+
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+ )
+
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -220,6 +237,17 @@ def test_output_pretrained(self):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
+
+ # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+ attention_head_dim = 32
+ block_out_channels = (32, 64)
+
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+ )
+
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -329,3 +357,17 @@ def test_output_pretrained_ve_large(self):
def test_forward_with_norm_groups(self):
# not required for this model
pass
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "UNetMidBlock2D",
+ }
+
+ block_out_channels = (32, 64, 64, 64)
+
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, block_out_channels=block_out_channels
+ )
+
+ def test_effective_gradient_checkpointing(self):
+ super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index f37d598c8387..ccd5567106d2 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -840,6 +840,14 @@ def callback_on_step_end(pipe, i, t, callback_kwargs):
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
+ def test_pipeline_accept_tuple_type_unet_sample_size(self):
+ # the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size
+ sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+ sample_size = [60, 80]
+ customised_unet = UNet2DConditionModel(sample_size=sample_size)
+ pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)
+ assert pipe.unet.config.sample_size == sample_size
+
@slow
@require_torch_gpu
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 07ce5487f256..a6f718ae4fbb 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -103,6 +103,8 @@ def get_dummy_components(self):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 43b01c40f5bb..423c82e0602e 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -1802,6 +1802,16 @@ def test_pipe_same_device_id_offload(self):
sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5
+ def test_wrong_model(self):
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ with self.assertRaises(ValueError) as error_context:
+ _ = StableDiffusionPipeline.from_pretrained(
+ "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
+ )
+
+ assert "is of type" in str(error_context.exception)
+ assert "but should be" in str(error_context.exception)
+
@slow
@require_torch_gpu