Flux Fill, Canny, Depth, Redux (#9985)
* update

---------

Co-authored-by: yiyixuxu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Nov 23, 2024
1 parent b5fd6f1 commit 7ac6e28
Showing 18 changed files with 4,189 additions and 8 deletions.
153 changes: 149 additions & 4 deletions docs/source/en/api/pipelines/flux.md
@@ -22,12 +22,20 @@ Flux can be quite expensive to run on consumer hardware devices. However, you ca

</Tip>

-Flux comes in two variants:
Flux comes in the following variants:

-* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`)
-* Guidance-distilled (`black-forest-labs/FLUX.1-dev`)
| model type | model id |
|:----------:|:--------:|
| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) |
| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) |
| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) |
| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |

-Both checkpoints have slightly difference usage which we detail below.
All checkpoints have slightly different usage, which we detail below.

### Timestep-distilled

@@ -77,7 +85,132 @@ out = pipe(
out.save("image.png")
```

### Fill Inpainting/Outpainting

* The Flux Fill pipeline does not require a `strength` input, unlike regular inpainting pipelines.
* It supports both inpainting and outpainting; an outpainting sketch follows the inpainting example below.

```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")

repo_id = "black-forest-labs/FLUX.1-Fill-dev"
pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")

image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("output.png")
```
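
Outpainting works the same way: pad the image and mark the padded border as the region to fill. A minimal sketch, assuming the padded dimensions are ones the pipeline accepts; the padding size and prompt here are illustrative, not from the original example:

```python
import torch
from PIL import Image, ImageOps
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")

# Pad every side by 128 pixels; the border is the region to outpaint.
pad = 128
padded = ImageOps.expand(image, border=pad, fill="white")

# White (255) marks pixels to generate, black (0) marks pixels to keep.
mask = Image.new("L", padded.size, 255)
mask.paste(Image.new("L", image.size, 0), (pad, pad))

image = pipe(
    prompt="a white paper cup on a wooden table",
    image=padded,
    mask_image=mask,
    height=padded.height,
    width=padded.width,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("outpaint.png")
```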

### Canny Control

**Note:** `black-forest-labs/FLUX.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the underlying model's outputs. Canny Control is an alternate architecture that achieves effectively the same results: the control condition is concatenated channel-wise with the model input, and the transformer is trained to follow the condition as closely as possible, learning structural control directly.
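
Conceptually (this is a simplification, not the library's internal code), the control image is packed into latents and concatenated to the image latents along the channel dimension, which is why these checkpoints have a wider input projection than the base model:

```python
import torch

# packed image latents and packed control latents, shape (batch, seq_len, channels)
latents = torch.randn(1, 4096, 64)
control_latents = torch.randn(1, 4096, 64)

# channel-wise concatenation doubles the input width: 64 -> 128
hidden_states = torch.cat([latents, control_latents], dim=2)
print(hidden_states.shape)  # torch.Size([1, 4096, 128])
```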

```python
# !pip install -U controlnet-aux
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```
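
The Canny LoRA variant from the table above can be applied on top of the base flux-dev checkpoint instead of downloading the full control transformer. A minimal sketch, assuming the LoRA repository loads directly with `load_lora_weights`:

```python
# !pip install -U controlnet-aux
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# base flux-dev pipeline with the Canny control LoRA loaded on top
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```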

### Depth Control

**Note:** `black-forest-labs/FLUX.1-Depth-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the underlying model's outputs. Depth Control is an alternate architecture that achieves effectively the same results: the control condition is concatenated channel-wise with the model input, and the transformer is trained to follow the condition as closely as possible, learning structural control directly.

```python
# !pip install git+https://github.com/asomoza/image_gen_aux.git
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```
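
A Depth LoRA variant exists as well; the sketch below assumes it loads on the base flux-dev checkpoint the same way as the Canny LoRA above:

```python
# !pip install git+https://github.com/asomoza/image_gen_aux.git
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")

control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = pipe(
    prompt="A robot made of exotic candies and chocolates of different kinds.",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```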

### Redux

* The Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell for image-to-image generation; a flux-schnell variant is sketched after the example below.
* First use `FluxPriorReduxPipeline` to obtain the `prompt_embeds` and `pooled_prompt_embeds`, then feed them into `FluxPipeline` for image-to-image generation.
* When using `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline to save VRAM.

```python
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image
device = "cuda"
dtype = torch.bfloat16


repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
repo_base = "black-forest-labs/FLUX.1-dev"
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
pipe = FluxPipeline.from_pretrained(
    repo_base,
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16
).to(device)

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(
    guidance_scale=2.5,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images
images[0].save("flux-redux.png")
```
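
Since Redux also works with the timestep-distilled flux-schnell, only the base pipeline and sampling settings change while the prior stays the same. A sketch continuing from the code above; the zero guidance and four steps are assumptions based on schnell's usual settings:

```python
# swap the base pipeline for flux-schnell; pipe_prior_output is reused as-is
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to(device)

images = pipe(
    guidance_scale=0.0,
    num_inference_steps=4,
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images
images[0].save("flux-redux-schnell.png")
```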

## Running FP16 inference

Flux can generate high-quality images with FP16 (for example, to accelerate inference on Turing/Volta GPUs), but it produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing the text encoders to run in FP32 removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.

FP16 inference code:
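
A minimal sketch of the idea described above (the full example is collapsed in this diff; the checkpoint and sampling settings here are illustrative):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)

# run the pipeline in fp16...
pipe.to(torch.float16)
# ...but keep the text encoders in fp32 so their activations are not clipped
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
pipe.to("cuda")

out = pipe(
    prompt="A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
out.save("image.png")
```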
@@ -188,3 +321,15 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxControlNetImg2ImgPipeline
- all
- __call__

## FluxControlPipeline

[[autodoc]] FluxControlPipeline
- all
- __call__

## FluxControlImg2ImgPipeline

[[autodoc]] FluxControlImg2ImgPipeline
- all
- __call__
7 changes: 6 additions & 1 deletion scripts/convert_flux_to_diffusers.py
@@ -37,6 +37,8 @@
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--in_channels", type=int, default=64)
parser.add_argument("--out_channels", type=int, default=None)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
        num_single_layers = 38
        inner_dim = 3072
        mlp_ratio = 4.0

        converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
            original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )
-        transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+        transformer = FluxTransformer2DModel(
+            in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
+        )
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)

        print(
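
The new `--in_channels`/`--out_channels` flags exist because the Fill and Control checkpoints concatenate extra conditioning channels onto the 64 packed latent channels while still predicting 64 output channels. A sketch of the resulting constructor calls; the 128/384 values are assumptions matching the released checkpoints' configs, and the tiny layer counts are only to keep the sketch lightweight:

```python
from diffusers import FluxTransformer2DModel

# tiny layer counts for illustration; the real checkpoints use 19/38
kwargs = dict(num_layers=1, num_single_layers=1, guidance_embeds=True)

# base FLUX.1-dev: 64 packed latent channels in and out
base = FluxTransformer2DModel(in_channels=64, **kwargs)
assert base.out_channels == 64  # out_channels falls back to in_channels

# Canny/Depth Control: image latents + control latents -> 128 in, 64 out
control = FluxTransformer2DModel(in_channels=128, out_channels=64, **kwargs)

# Fill: image latents + masked-image latents + mask channels -> 384 in, 64 out
fill = FluxTransformer2DModel(in_channels=384, out_channels=64, **kwargs)
```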
10 changes: 10 additions & 0 deletions src/diffusers/__init__.py
@@ -269,12 +269,16 @@
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
"FluxControlPipeline",
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
@@ -321,6 +325,7 @@
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"ReduxImageEncoder",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -734,12 +739,16 @@
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxControlPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
@@ -786,6 +795,7 @@
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
ReduxImageEncoder,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
7 changes: 5 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -238,6 +238,7 @@ def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
@@ -248,7 +249,7 @@
        axes_dims_rope: Tuple[int] = (16, 56, 56),
    ):
        super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@
        )

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
@@ -449,13 +450,15 @@ def forward(
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -127,12 +127,17 @@
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
"FluxFillPipeline",
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -521,12 +526,17 @@
            VQDiffusionPipeline,
        )
        from .flux import (
            FluxControlImg2ImgPipeline,
            FluxControlNetImg2ImgPipeline,
            FluxControlNetInpaintPipeline,
            FluxControlNetPipeline,
            FluxControlPipeline,
            FluxFillPipeline,
            FluxImg2ImgPipeline,
            FluxInpaintPipeline,
            FluxPipeline,
            FluxPriorReduxPipeline,
            ReduxImageEncoder,
        )
        from .hunyuandit import HunyuanDiTPipeline
        from .i2vgen_xl import I2VGenXLPipeline
12 changes: 11 additions & 1 deletion src/diffusers/pipelines/flux/__init__.py
@@ -12,7 +12,7 @@

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
@@ -22,25 +22,35 @@

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_flux"] = ["ReduxImageEncoder"]
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modeling_flux import ReduxImageEncoder
        from .pipeline_flux import FluxPipeline
        from .pipeline_flux_control import FluxControlPipeline
        from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
        from .pipeline_flux_controlnet import FluxControlNetPipeline
        from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
        from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
        from .pipeline_flux_fill import FluxFillPipeline
        from .pipeline_flux_img2img import FluxImg2ImgPipeline
        from .pipeline_flux_inpaint import FluxInpaintPipeline
        from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
    import sys
