diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ab733054fbd3..27e9fe5e191b 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -400,6 +400,8 @@
title: DiT
- local: api/pipelines/flux
title: Flux
+ - local: api/pipelines/control_flux_inpaint
+ title: FluxControlInpaint
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
diff --git a/docs/source/en/api/pipelines/control_flux_inpaint.md b/docs/source/en/api/pipelines/control_flux_inpaint.md
new file mode 100644
index 000000000000..0cf4f4b4225e
--- /dev/null
+++ b/docs/source/en/api/pipelines/control_flux_inpaint.md
@@ -0,0 +1,89 @@
+
+
+# FluxControlInpaint
+
+FluxControlInpaintPipeline implements inpainting for the Flux.1 Depth/Canny models: it takes an image and a mask as input and returns the inpainted image, following the structure provided by the Depth/Canny control image.
+
+FLUX.1 Depth and Canny [dev] are 12 billion parameter rectified flow transformers capable of generating an image based on a text description while following the structure of a given input image. **These are not ControlNet models**.
+
+| Control type | Developer | Link |
+| -------- | ---------- | ---- |
+| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
+| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
+
+
+
+
+Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
+
+
+
+```python
+import torch
+from diffusers import FluxControlInpaintPipeline
+from diffusers.models.transformers import FluxTransformer2DModel
+from transformers import T5EncoderModel
+from diffusers.utils import load_image, make_image_grid
+from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
+from PIL import Image
+import numpy as np
+
+pipe = FluxControlInpaintPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Depth-dev",
+ torch_dtype=torch.bfloat16,
+)
+# use the following lines if you have GPU constraints (they swap in an NF4-quantized transformer and text encoder)
+# ---------------------------------------------------------------
+transformer = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+text_encoder_2 = T5EncoderModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+)
+pipe.transformer = transformer
+pipe.text_encoder_2 = text_encoder_2
+pipe.enable_model_cpu_offload()
+# ---------------------------------------------------------------
+pipe.to("cuda")
+
+prompt = "a blue robot singing opera with human-like expressions"
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+head_mask = np.zeros_like(image)
+head_mask[65:580,300:642] = 255
+mask_image = Image.fromarray(head_mask)
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(image)[0].convert("RGB")
+
+output = pipe(
+ prompt=prompt,
+ image=image,
+ control_image=control_image,
+ mask_image=mask_image,
+ num_inference_steps=30,
+ strength=0.9,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+).images[0]
+make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
+```
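+
+The memory optimizations mentioned above apply to this pipeline as well. A minimal, illustrative sketch of the offloading and VAE slicing/tiling toggles (use whichever combination your hardware requires):
+
+```python
+import torch
+from diffusers import FluxControlInpaintPipeline
+
+pipe = FluxControlInpaintPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()  # keep only the active component on the GPU
+pipe.vae.enable_slicing()  # decode batched images one at a time
+pipe.vae.enable_tiling()  # decode large images tile by tile
+```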
+
+## FluxControlInpaintPipeline
+[[autodoc]] FluxControlInpaintPipeline
+ - all
+ - __call__
+
+
+## FluxPipelineOutput
+[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
index dbcd1b1486b2..2ff2a9293130 100644
--- a/docs/source/en/quantization/gguf.md
+++ b/docs/source/en/quantization/gguf.md
@@ -25,9 +25,9 @@ pip install -U gguf
Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
-When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.unint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
+When using GGUF checkpoints, the quantized weights remain in a low memory `dtype` (typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
-The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original (`numpy`)[https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py] implementation by [compilade](https://github.com/compilade).
+The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the PyTorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
```python
import torch
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index 6c2df7514d5e..3eef5238f1ce 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
Diffusers currently supports the following quantization methods.
-- [BitsandBytes]()
-- [TorchAO]()
-- [GGUF]()
+- [BitsandBytes](./bitsandbytes.md)
+- [TorchAO](./torchao.md)
+- [GGUF](./gguf.md)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py
index 3ece670e5bde..8cf2830f25ab 100644
--- a/examples/community/pipeline_hunyuandit_differential_img2img.py
+++ b/examples/community/pipeline_hunyuandit_differential_img2img.py
@@ -1008,6 +1008,8 @@ def __call__(
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md
new file mode 100644
index 000000000000..fe861d62472b
--- /dev/null
+++ b/examples/dreambooth/README_sana.md
@@ -0,0 +1,127 @@
+# DreamBooth training example for SANA
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629).
+
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/dreambooth` folder and run
+```bash
+pip install -r requirements_sana.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, if you specify torch compile mode to True, you can get dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.14.0` installed in your environment.
+
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-sana-lora"
+
+accelerate launch train_dreambooth_lora_sana.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --use_8bit_adam \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To use `push_to_hub`, make sure you're logged in to your Hugging Face account:
+
+```bash
+huggingface-cli login
+```
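+
+If you prefer to authenticate from Python instead (for example, inside a notebook), a minimal sketch using the `huggingface_hub` login helper:
+
+```python
+from huggingface_hub import login
+
+# Prompts for a User Access Token (it needs "write" permission) and caches it locally,
+# so later runs of the training script can push checkpoints to the Hub.
+login()
+```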
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+## Notes
+
+Additionally, we welcome you to explore the following CLI arguments:
+
+* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string, e.g. "to_k,to_q,to_v" will result in LoRA training of the attention layers only.
+* `--complex_human_instruction`: Instructions for complex human attention as shown in [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).
+* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
+
+
+We provide several options for memory optimization:
+
+* `--offload`: When enabled, we will offload the text encoder and VAE to the CPU when they are not used.
+* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
+* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
+
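+Once training finishes, the saved LoRA weights can be loaded back into `SanaPipeline` for inference. A minimal sketch (the repository id below is a placeholder for wherever you pushed or saved the weights; see the note on preferred dtypes right after this snippet):
+
+```python
+import torch
+from diffusers import SanaPipeline
+
+pipe = SanaPipeline.from_pretrained(
+    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# "your-username/trained-sana-lora" is a placeholder for your Hub repo or local output_dir
+pipe.load_lora_weights("your-username/trained-sana-lora")
+
+image = pipe("A photo of sks dog in a bucket").images[0]
+image.save("sks_dog.png")
+```
+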
+Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
\ No newline at end of file
diff --git a/examples/dreambooth/requirements_sana.txt b/examples/dreambooth/requirements_sana.txt
new file mode 100644
index 000000000000..04b4bd6c29c0
--- /dev/null
+++ b/examples/dreambooth/requirements_sana.txt
@@ -0,0 +1,8 @@
+accelerate>=1.0.0
+torchvision
+transformers>=4.47.0
+ftfy
+tensorboard
+Jinja2
+peft>=0.14.0
+sentencepiece
\ No newline at end of file
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
new file mode 100644
index 000000000000..4baa9f194feb
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -0,0 +1,1552 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, Gemma2Model
+
+import diffusers
+from diffusers import (
+ AutoencoderDC,
+ FlowMatchEulerDiscreteScheduler,
+ SanaPipeline,
+ SanaTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.32.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Sana DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Sana diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
+
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+TODO
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+TODO
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "sana",
+ "sana-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=300,
+ help="Maximum sequence length to use with with the Gemma model",
+ )
+ parser.add_argument(
+ "--complex_human_instruction",
+ type=str,
+ default=None,
+ help="Instructions for complex human attention: https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sana-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string, e.g. "to_k,to_q,to_v" will result in LoRA training of the attention layers only'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
+ pipeline.transformer = pipeline.transformer.to(torch.float16)
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder = Gemma2Model.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderDC.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SanaTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ # VAE should always be kept in fp32 for SANA (?)
+ vae.to(dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ # because Gemma2 is particularly suited for bfloat16.
+ text_encoder.to(dtype=torch.bfloat16)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v"]
+
+ # now we will add new LoRA weights to the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ SanaPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
+ text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+ with torch.no_grad():
+ prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
+ prompt,
+ max_sequence_length=args.max_sequence_length,
+ complex_human_instruction=args.complex_human_instruction,
+ )
+ if args.offload:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ return prompt_embeds, prompt_attention_mask
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(
+ args.class_prompt, text_encoding_pipeline
+ )
+
+ # Clear the memory here
+ if not train_dataset.custom_instance_prompts:
+ del text_encoder, tokenizer
+ free_memory()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ prompt_attention_mask = instance_prompt_attention_mask
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
+
+ vae_config_scaling_factor = vae.config.scaling_factor
+ if args.cache_latents:
+ latents_cache = []
+ vae = vae.to("cuda")
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent)
+
+ if args.validation_prompt is None:
+ del vae
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-sana-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
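+ # get_sigmas looks up each sampled timestep's index in the scheduler copy's timestep schedule,
+ # gathers the corresponding sigma values, and unsqueezes them to `n_dim` so they broadcast
+ # against the latent tensor in the flow-matching interpolation below.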
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step]
+ else:
+ vae = vae.to(accelerator.device)
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent
+ if args.offload:
+ vae = vae.to("cpu")
+ model_input = model_input * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
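+ # The flow-matching velocity target is the time derivative of the interpolation
+ # zt = (1 - sigma) * x + sigma * noise, i.e. noise - x (computed below as noise - model_input).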
+ # flow matching loss
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=torch.float32,
+ )
+ pipeline_args = {
+ "prompt": args.validation_prompt,
+ "complex_human_instruction": args.complex_human_instruction,
+ }
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+ free_memory()
+
+ images = None
+ del pipeline
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ SanaPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=torch.float32,
+ )
+ pipeline.transformer = pipeline.transformer.to(torch.float16)
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {
+ "prompt": args.validation_prompt,
+ "complex_human_instruction": args.complex_human_instruction,
+ }
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ images = None
+ del pipeline
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index e2351a0c53b8..91b297f8c007 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -277,6 +277,7 @@
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
+ "FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
@@ -765,6 +766,7 @@
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
+ FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index d59830e614e9..b59150376599 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder):
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
"Mochi1LoraLoaderMixin",
+ "SanaLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -92,6 +93,7 @@ def text_encoder_attn_modules(text_encoder):
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Mochi1LoraLoaderMixin,
+ SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 869a5cca24f5..b8c44e480093 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -3562,6 +3562,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
super().unfuse_lora(components=components)
+class SanaLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
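+
+ Example (a minimal usage sketch; the checkpoint id and LoRA path are illustrative, and the LoRA
+ directory would typically be the `output_dir` written by the DreamBooth LoRA training script above):
+
+ ```py
+ import torch
+
+ from diffusers import SanaPipeline
+
+ pipe = SanaPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
+ ).to("cuda")
+ pipe.load_lora_weights("path/to/sana-lora")  # directory produced by `save_lora_weights`
+ image = pipe(prompt="a photo of sks dog").images[0]
+ ```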
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ between text encoder lora layers.
+ transformer (`SanaTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3dddb94f30c1..a791a250af08 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -54,6 +54,7 @@
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
+ "SanaTransformer2DModel": lambda model_cls, weights: weights,
}
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 4e288737fe88..ded466b35e9a 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -151,6 +151,8 @@
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+ "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
+ "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
@@ -587,7 +589,13 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
- model_type = "flux-dev"
+ if checkpoint["img_in.weight"].shape[1] == 384:
+ model_type = "flux-fill"
+
+ elif checkpoint["img_in.weight"].shape[1] == 128:
+ model_type = "flux-depth"
+ else:
+ model_type = "flux-dev"
else:
model_type = "flux-schnell"
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 6749c7f17254..4d1dae879f11 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
self._chunk_dim = dim
def forward(
- self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
+ joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
@@ -206,7 +211,9 @@ def forward(
# Attention.
attn_output, context_attn_output = self.attn(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ **joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ def forward(
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
- attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 0f4b555a2d71..69b3ee8466f4 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -691,7 +691,7 @@ def _get_positional_embeddings(
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
- joint_pos_embedding = torch.zeros(
+ joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
-def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+def get_2d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+):
+ """
+ RoPE for image tokens with 2d structure.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension size.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ device (`torch.device`, *optional*):
+ The device used to create tensors.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ Output format for the embedding. The `"np"` path is deprecated; pass `"pt"` to compute the embedding
+ with `torch` on `device`.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return _get_2d_rotary_pos_embed_np(
+ embed_dim=embed_dim,
+ crops_coords=crops_coords,
+ grid_size=grid_size,
+ use_real=use_real,
+ )
+ start, stop = crops_coords
+ # scaling the end by (steps - 1) / steps matches np.linspace(..., endpoint=False)
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+ )
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index dba67f45fce9..41224e42d2a5 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -18,7 +18,8 @@
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
Attention,
AttentionProcessor,
@@ -180,7 +181,7 @@ def forward(
return hidden_states
-class SanaTransformer2DModel(ModelMixin, ConfigMixin):
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
@@ -363,8 +364,24 @@ def forward(
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -460,6 +477,11 @@ def custom_forward(*inputs):
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
if not return_dict:
return (output,)
+
return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 79452bb85176..79c4069e9a37 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -411,11 +411,15 @@ def custom_forward(*inputs):
hidden_states,
encoder_hidden_states,
temb,
+ joint_attention_kwargs,
**ckpt_kwargs,
)
elif not is_skip:
encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index 5972505f2897..d05af686dede 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -97,6 +97,7 @@ def __init__(
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +123,7 @@ def __init__(
super().__init__()
self.sample_size = sample_size
- time_embed_dim = block_out_channels[0] * 4
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index e7fd7ec78bed..ce291e5ceb45 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -128,6 +128,7 @@
]
_import_structure["flux"] = [
"FluxControlPipeline",
+ "FluxControlInpaintPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -539,6 +540,7 @@
)
from .flux import (
FluxControlImg2ImgPipeline,
+ FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index 45e17f3de1e2..c8464f8108ea 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -925,7 +925,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
index 3570368a5ca1..72e1b578f2ca 100644
--- a/src/diffusers/pipelines/flux/__init__.py
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
+ _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
@@ -44,6 +45,7 @@
from .pipeline_flux import FluxPipeline
from .pipeline_flux_control import FluxControlPipeline
from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
+ from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
new file mode 100644
index 000000000000..a9ac1c72c6ed
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -0,0 +1,1141 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ import torch
+ from diffusers import FluxControlInpaintPipeline
+ from diffusers.models.transformers import FluxTransformer2DModel
+ from transformers import T5EncoderModel
+ from diffusers.utils import load_image, make_image_grid
+ from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
+ from PIL import Image
+ import numpy as np
+
+ pipe = FluxControlInpaintPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Depth-dev",
+ torch_dtype=torch.bfloat16,
+ )
+ # use following lines if you have GPU constraints
+ # ---------------------------------------------------------------
+ transformer = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
+ )
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+ )
+ pipe.transformer = transformer
+ pipe.text_encoder_2 = text_encoder_2
+ pipe.enable_model_cpu_offload()
+ # ---------------------------------------------------------------
+ pipe.to("cuda")
+
+ prompt = "a blue robot singing opera with human-like expressions"
+ image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+ head_mask = np.zeros_like(image)
+ head_mask[65:580, 300:642] = 255
+ mask_image = Image.fromarray(head_mask)
+
+ processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+ control_image = processor(image)[0].convert("RGB")
+
+ output = pipe(
+ prompt=prompt,
+ image=image,
+ control_image=control_image,
+ mask_image=mask_image,
+ num_inference_steps=30,
+ strength=0.9,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+ ).images[0]
+ make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save(
+ "output.png"
+ )
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
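+ # mu interpolates linearly between base_shift (at base_seq_len image tokens) and max_shift
+ # (at max_seq_len tokens). For illustration, a 1024x1024 input yields 64x64 = 4096 packed
+ # latent tokens, which with the defaults gives mu = max_shift = 1.16.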
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlInpaintPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The Flux pipeline for image inpainting using Flux-dev-Depth/Canny.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.vae.config.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
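+ # Shift and scale the latents with the VAE config factors so they are on the scale the
+ # transformer expects.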
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
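+ # For illustration: num_inference_steps=30 with strength=0.9 gives init_timestep=27 and
+ # t_start=3, so denoising runs over the last 27 of the 30 scheduled timesteps.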
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
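+ # Each packed latent token gets a 3-component position id: component 0 is left at zero, while
+ # components 1 and 2 hold the token's row and column indices. The transformer uses these ids
+ # to build its rotary positional embeddings.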
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
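+ # fold each 2x2 patch of latent pixels into the channel dim:
+ # (B, C, H, W) -> (B, (H // 2) * (W // 2), C * 4), e.g. (1, 16, 64, 64) -> (1, 1024, 64)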
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
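+ # inverse of _pack_latents: (B, (H // 2) * (W // 2), C * 4) -> (B, C, H, W)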
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
+ for processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
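+ # flow-matching img2img: interpolate the clean image latents towards noise at the starting
+ # timestep (via scheduler.scale_noise), then pack them into the Flux token layout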
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, noise, image_latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_mask_latents(
+ self,
+ image,
+ mask_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ masked_image = image * (1 - mask_image)
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask_image = torch.nn.functional.interpolate(mask_image, size=(height, width))
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == num_channels_latents:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask_image.shape[0] < batch_size:
+ if not batch_size % mask_image.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask_image = self._pack_latents(
+ mask_image.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
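+ # concatenate the packed masked-image latents and the packed mask along the last (channel) dim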
+ masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1)
+
+ return mask_image, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+ be used instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+ list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+ a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if passing latents directly they are not encoded again.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ The control image (for example a depth map or Canny edges, matching the loaded FLUX.1 Depth/Canny
+ checkpoint) that provides structural guidance to the transformer during generation. If the type is
+ specified as `torch.Tensor`, it is used as is; `PIL.Image.Image` inputs are preprocessed first. The
+ dimensions of the output image default to `image`'s dimensions. If `height` and/or `width` are
+ passed, the control image is resized accordingly.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+ `(B, H, W)`, `(1, H, W)` or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`
+ or `(H, W)`.
+ masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image`, generated by the VAE. If not provided, the mask
+ latents tensor will be generated from `mask_image`.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, the added noise is maximum and the
+ denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of
+ 1 essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Preprocess mask and image
+ num_channels_latents = self.vae.config.latent_channels
+ if masked_image_latents is not None:
+ # pre-computed masked_image_latents and mask_image
+ masked_image_latents = masked_image_latents.to(device)
+ mask = mask_image.to(device)
+ else:
+ mask, masked_image_latents = self.prepare_mask_latents(
+ image,
+ mask_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
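+ # resolution-dependent timestep shift ("mu") for the flow-match scheduler: larger images get a larger shift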
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 8
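+ # the control models concatenate packed image and control latents, so transformer.in_channels is
+ # 8x the raw latent channels (4x from 2x2 packing, 2x from the two concatenated streams)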
+
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
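+ # encode the control image with the VAE and pack it the same way as the image latents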
+ if control_image.ndim == 4:
+ control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height_8 = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width_8 = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
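+ # the packed control latents are concatenated to the image latents along the channel dim
+ # (dim=2 of the packed (B, seq_len, channels) tensor) at every denoising step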
+ latent_model_input = torch.cat([latents, control_image], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # re-noise the original image latents to the next timestep so they can be blended with the partially denoised latents
+ init_mask = mask
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ image_latents, torch.tensor([noise_timestep]), noise
+ )
+ else:
+ init_latents_proper = image_latents
+ init_latents_proper = self._pack_latents(
+ init_latents_proper, batch_size * num_images_per_prompt, num_channels_latents, height_8, width_8
+ )
+
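+ # white mask regions (1) keep the denoised latents, black regions (0) keep the re-noised original image latents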
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index bda718cb197d..6f542cb59f46 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -798,7 +798,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 543af08f2e3c..7180601dad41 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 6d2afc56ed39..fbb30e304d65 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index dfc0a9be278d..aac4e32e33f0 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -188,6 +188,7 @@ def __init__(
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3DModel,
+ force_zeros_for_empty_prompt: bool = False,
):
super().__init__()
@@ -205,10 +206,11 @@ def __init__(
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
self.tokenizer_max_length = (
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
)
self.default_height = 480
self.default_width = 848
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
def _get_t5_prompt_embeds(
self,
@@ -236,7 +238,11 @@ def _get_t5_prompt_embeds(
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
- if prompt == "" or prompt[-1] == "":
+
+ # The original Mochi implementation zeros out empty negative prompts, but this can lead to overflow
+ # when the entire pipeline is placed under an autocast context. This flag lets users enable the
+ # zeroing behaviour when necessary.
+ if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
text_input_ids = torch.zeros_like(text_input_ids, device=device)
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
index 408992378538..dea1f12696b2 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -818,7 +818,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index c6e7554e6b69..cf4d41fee487 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -170,7 +170,6 @@ def __init__(
pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
)
- # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 80736d498e0f..2df6586d0bc4 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -16,21 +16,25 @@
import inspect
import re
import urllib.parse as ul
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, SanaTransformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -130,7 +134,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class SanaPipeline(DiffusionPipeline):
+class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
r"""
Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629).
"""
@@ -177,6 +181,7 @@ def encode_prompt(
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -210,6 +215,15 @@ def encode_prompt(
if device is None:
device = self._execution_device
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -305,6 +319,11 @@ def encode_prompt(
negative_prompt_embeds = None
negative_prompt_attention_mask = None
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -554,6 +573,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -590,6 +613,7 @@ def __call__(
return_dict: bool = True,
clean_caption: bool = True,
use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 300,
@@ -662,6 +686,10 @@ def __call__(
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
@@ -722,6 +750,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default height and width to transformer
@@ -733,6 +762,7 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# 3. Encode input prompt
(
@@ -753,6 +783,7 @@ def __call__(
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -801,6 +832,7 @@ def __call__(
encoder_attention_mask=prompt_attention_mask,
timestep=timestep,
return_dict=False,
+ attention_kwargs=self.attention_kwargs,
)[0]
noise_pred = noise_pred.float()
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 3350c3373ecf..6a653f183bba 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -289,6 +289,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index 19399a724a41..971817f7b777 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -291,14 +291,17 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 41a471275fa2..d45c93880bc5 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -318,6 +318,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index c6434c6f87c6..01500426305c 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -381,6 +381,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index e148c025d191..9b36be9e0604 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -392,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class FluxControlInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index 15f8ebf4505c..aa7a1619a183 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -29,7 +29,6 @@
)
from diffusers.utils.testing_utils import (
floats_tensor,
- is_peft_available,
is_torch_version,
require_peft_backend,
skip_mps,
@@ -37,9 +36,6 @@
)
-if is_peft_available():
- pass
-
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py
index 0a07e3d096bb..4bfc5a824d43 100644
--- a/tests/lora/test_lora_layers_mochi.py
+++ b/tests/lora/test_lora_layers_mochi.py
@@ -23,7 +23,6 @@
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.utils.testing_utils import (
floats_tensor,
- is_peft_available,
is_torch_version,
require_peft_backend,
skip_mps,
@@ -31,9 +30,6 @@
)
-if is_peft_available():
- pass
-
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
new file mode 100644
index 000000000000..499ca89262a0
--- /dev/null
+++ b/tests/lora/test_lora_layers_sana.py
@@ -0,0 +1,138 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import Gemma2ForCausalLM, GemmaTokenizer
+
+from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = SanaPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ scheduler_kwargs = {}
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ transformer_kwargs = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_layers": 1,
+ "num_attention_heads": 2,
+ "attention_head_dim": 4,
+ "num_cross_attention_heads": 2,
+ "cross_attention_head_dim": 4,
+ "cross_attention_dim": 8,
+ "caption_channels": 8,
+ "sample_size": 32,
+ }
+ transformer_cls = SanaTransformer2DModel
+ vae_kwargs = {
+ "in_channels": 3,
+ "latent_channels": 4,
+ "attention_head_dim": 2,
+ "encoder_block_types": (
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ "decoder_block_types": (
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ "encoder_block_out_channels": (8, 8),
+ "decoder_block_out_channels": (8, 8),
+ "encoder_qkv_multiscales": ((), (5,)),
+ "decoder_qkv_multiscales": ((), (5,)),
+ "encoder_layers_per_block": (1, 1),
+ "decoder_layers_per_block": [1, 1],
+ "downsample_block_type": "conv",
+ "upsample_block_type": "interpolate",
+ "decoder_norm_types": "rms_norm",
+ "decoder_act_fns": "silu",
+ "scaling_factor": 0.41407,
+ }
+ vae_cls = AutoencoderDC
+ tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
+ text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
+
+ @property
+ def output_shape(self):
+ return (1, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "num_inference_steps": 4,
+ "guidance_scale": 4.5,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ "complex_human_instruction": None,
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in Sana.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Not supported in Sana.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Sana.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Sana.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Sana.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Sana.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Sana.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Sana.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index b37a2a297e04..8c42f9c86ee9 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -29,7 +29,6 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
- is_peft_available,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
@@ -37,9 +36,6 @@
)
-if is_peft_available():
- pass
-
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 990cf71f298e..ac7a944cd026 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -1545,7 +1545,12 @@ def test_lora_fuse_nan(self):
"adapter-1"
].weight += float("inf")
else:
- pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+ named_modules = [name for name, _ in pipe.transformer.named_modules()]
+ has_attn1 = any("attn1" in name for name in named_modules)
+ if has_attn1:
+ pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
+ else:
+ pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
new file mode 100644
index 000000000000..c5ff02a525f2
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -0,0 +1,215 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxControlInpaintPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+ torch_device,
+)
+
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+)
+
+
+class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxControlInpaintPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=8,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ image = Image.new("RGB", (8, 8), 0)
+ control_image = Image.new("RGB", (8, 8), 0)
+ mask_image = Image.new("RGB", (8, 8), 255)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "control_image": control_image,
+ "generator": generator,
+ "image": image,
+ "mask_image": mask_image,
+ "strength": 0.8,
+ "num_inference_steps": 2,
+ "guidance_scale": 30.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ # def test_flux_different_prompts(self):
+ # pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ # inputs = self.get_dummy_inputs(torch_device)
+ # output_same_prompt = pipe(**inputs).images[0]
+
+ # inputs = self.get_dummy_inputs(torch_device)
+ # inputs["prompt_2"] = "a different prompt"
+ # output_different_prompts = pipe(**inputs).images[0]
+
+ # max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # # Outputs should be different here
+ # # For some reasons, they don't show large differences
+ # assert max_diff > 1e-6
+
+ def test_flux_prompt_embeds(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_with_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = inputs.pop("prompt")
+
+ (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
+ prompt,
+ prompt_2=None,
+ device=torch_device,
+ max_sequence_length=inputs["max_sequence_length"],
+ )
+ output_with_embeds = pipe(
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ **inputs,
+ ).images[0]
+
+ max_diff = np.abs(output_with_prompt - output_with_embeds).max()
+ assert max_diff < 1e-4
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index 2192c171aa22..bbcf6d210ce5 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -275,7 +275,7 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
- def test_cogvideox(self):
+ def test_mochi(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)