-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FLUX image-to-image and inpainting (#6798)
## Summary This PR adds support for Image-to-Image and inpainting workflows with the FLUX model. Full changelog: - Split out `FLUX VAE Encode` and `FLUX VAE Decode` nodes - Renamed `FLUX Text-to-Image` node to `FLUX Denoise` (since it now supports image-to-image too). This is a workflow-breaking change. - Added support for FLUX image-to-image via the `Latents` param on the FLUX denoising node. - Added support for FLUX masked inpainting via the `Denoise Mask` param on the FLUX denoising node. - Added "Denoise Start" and "Denoise End" params to the "FLUX Denoise" node. - Updated the "FLUX Text to Image" default workflow. - Added a "FLUX Image to Image" default workflow. ### Example FLUX inpainting workflow <img width="1282" alt="image" src="https://github.com/user-attachments/assets/86fc1170-e620-4412-8fd8-e119f875fc2e"> Input image ![image](https://github.com/user-attachments/assets/9c381b86-9f87-4257-bd2e-da22c56ca26c) Mask ![image](https://github.com/user-attachments/assets/8f774c5c-2a25-45fe-9d4b-b233e3d58d2c) Output image ![image](https://github.com/user-attachments/assets/8576a630-24ce-4a00-8052-e86bab59c855) ### Callouts for reviewers: - I renamed FLUXTextToImageInvocation -> FLUXDenoisingInvocation. This is, of course, a breaking change. It feels like the right move and now is the right time to do it. Any objection? - I added new `FLUX VAE Encode` and `FLUX VAE Decode` nodes. Alternatively, I could have tried to match these names to the corresponding SD nodes (e.g. `FLUX Image to Latents`, `FLUX Latents to Image`). Personally, I prefer the current names, but want to hear other opinions. ### Usage notes: - With the default dev timestep scheduler, the image structure is largely determined in the first ~3 steps. A consequence of this is that the denoise_start parameter provides limited 'granularity' of control. This will likely be improved in the future as we add more scheduler options. In the meantime, you will likely want to use small values for `denoise_start` (e.g. 0.03) to start denoising on step ~1-4 out of ~30. - Currently, there is no 'noise' parameter on the `FLUX Denoise` node, so the `denoise_end` parameter has limited utility. This will be added in the future. ## QA Instructions Test the following workflows: - [x] Vanilla FLUX text-to-image behaviour is unchanged - [x] Image-to-image with FLUX dev, no mask - [x] Image-to-image with FLUX dev, with mask - [x] Image-to-image with FLUX schnell, no mask (smoke test, not expected to work well) ## Merge Plan No special instructions. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_
- Loading branch information
Showing
16 changed files
with
1,412 additions
and
489 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
from typing import Callable, Optional | ||
|
||
import torch | ||
import torchvision.transforms as tv_transforms | ||
from torchvision.transforms.functional import resize as tv_resize | ||
|
||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation | ||
from invokeai.app.invocations.fields import ( | ||
DenoiseMaskField, | ||
FieldDescriptions, | ||
FluxConditioningField, | ||
Input, | ||
InputField, | ||
LatentsField, | ||
WithBoard, | ||
WithMetadata, | ||
) | ||
from invokeai.app.invocations.model import TransformerField | ||
from invokeai.app.invocations.primitives import LatentsOutput | ||
from invokeai.app.services.session_processor.session_processor_common import CanceledException | ||
from invokeai.app.services.shared.invocation_context import InvocationContext | ||
from invokeai.backend.flux.denoise import denoise | ||
from invokeai.backend.flux.inpaint_extension import InpaintExtension | ||
from invokeai.backend.flux.model import Flux | ||
from invokeai.backend.flux.sampling_utils import ( | ||
clip_timestep_schedule, | ||
generate_img_ids, | ||
get_noise, | ||
get_schedule, | ||
pack, | ||
unpack, | ||
) | ||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo | ||
from invokeai.backend.util.devices import TorchDevice | ||
|
||
|
||
@invocation( | ||
"flux_denoise", | ||
title="FLUX Denoise", | ||
tags=["image", "flux"], | ||
category="image", | ||
version="1.0.0", | ||
classification=Classification.Prototype, | ||
) | ||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): | ||
"""Run denoising process with a FLUX transformer model.""" | ||
|
||
# If latents is provided, this means we are doing image-to-image. | ||
latents: Optional[LatentsField] = InputField( | ||
default=None, | ||
description=FieldDescriptions.latents, | ||
input=Input.Connection, | ||
) | ||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified. | ||
denoise_mask: Optional[DenoiseMaskField] = InputField( | ||
default=None, | ||
description=FieldDescriptions.denoise_mask, | ||
input=Input.Connection, | ||
) | ||
denoising_start: float = InputField( | ||
default=0.0, | ||
ge=0, | ||
le=1, | ||
description=FieldDescriptions.denoising_start, | ||
) | ||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) | ||
transformer: TransformerField = InputField( | ||
description=FieldDescriptions.flux_model, | ||
input=Input.Connection, | ||
title="Transformer", | ||
) | ||
positive_text_conditioning: FluxConditioningField = InputField( | ||
description=FieldDescriptions.positive_cond, input=Input.Connection | ||
) | ||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") | ||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") | ||
num_steps: int = InputField( | ||
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50." | ||
) | ||
guidance: float = InputField( | ||
default=4.0, | ||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.", | ||
) | ||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.") | ||
|
||
@torch.no_grad() | ||
def invoke(self, context: InvocationContext) -> LatentsOutput: | ||
latents = self._run_diffusion(context) | ||
latents = latents.detach().to("cpu") | ||
|
||
name = context.tensors.save(tensor=latents) | ||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None) | ||
|
||
def _run_diffusion( | ||
self, | ||
context: InvocationContext, | ||
): | ||
inference_dtype = torch.bfloat16 | ||
|
||
# Load the conditioning data. | ||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name) | ||
assert len(cond_data.conditionings) == 1 | ||
flux_conditioning = cond_data.conditionings[0] | ||
assert isinstance(flux_conditioning, FLUXConditioningInfo) | ||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype) | ||
t5_embeddings = flux_conditioning.t5_embeds | ||
clip_embeddings = flux_conditioning.clip_embeds | ||
|
||
# Load the input latents, if provided. | ||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None | ||
if init_latents is not None: | ||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype) | ||
|
||
# Prepare input noise. | ||
noise = get_noise( | ||
num_samples=1, | ||
height=self.height, | ||
width=self.width, | ||
device=TorchDevice.choose_torch_device(), | ||
dtype=inference_dtype, | ||
seed=self.seed, | ||
) | ||
|
||
transformer_info = context.models.load(self.transformer.transformer) | ||
is_schnell = "schnell" in transformer_info.config.config_path | ||
|
||
# Calculate the timestep schedule. | ||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4 | ||
timesteps = get_schedule( | ||
num_steps=self.num_steps, | ||
image_seq_len=image_seq_len, | ||
shift=not is_schnell, | ||
) | ||
|
||
# Clip the timesteps schedule based on denoising_start and denoising_end. | ||
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end) | ||
|
||
# Prepare input latent image. | ||
if init_latents is not None: | ||
# If init_latents is provided, we are doing image-to-image. | ||
|
||
if is_schnell: | ||
context.logger.warning( | ||
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely " | ||
"to be poor. Consider using a FLUX dev model instead." | ||
) | ||
|
||
# Noise the orig_latents by the appropriate amount for the first timestep. | ||
t_0 = timesteps[0] | ||
x = t_0 * noise + (1.0 - t_0) * init_latents | ||
else: | ||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise). | ||
if self.denoising_start > 1e-5: | ||
raise ValueError("denoising_start should be 0 when initial latents are not provided.") | ||
|
||
x = noise | ||
|
||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any | ||
# denoising steps. | ||
if len(timesteps) <= 1: | ||
return x | ||
|
||
inpaint_mask = self._prep_inpaint_mask(context, x) | ||
|
||
b, _c, h, w = x.shape | ||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype) | ||
|
||
bs, t5_seq_len, _ = t5_embeddings.shape | ||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()) | ||
|
||
# Pack all latent tensors. | ||
init_latents = pack(init_latents) if init_latents is not None else None | ||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None | ||
noise = pack(noise) | ||
x = pack(x) | ||
|
||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly. | ||
assert image_seq_len == x.shape[1] | ||
|
||
# Prepare inpaint extension. | ||
inpaint_extension: InpaintExtension | None = None | ||
if inpaint_mask is not None: | ||
assert init_latents is not None | ||
inpaint_extension = InpaintExtension( | ||
init_latents=init_latents, | ||
inpaint_mask=inpaint_mask, | ||
noise=noise, | ||
) | ||
|
||
with transformer_info as transformer: | ||
assert isinstance(transformer, Flux) | ||
|
||
x = denoise( | ||
model=transformer, | ||
img=x, | ||
img_ids=img_ids, | ||
txt=t5_embeddings, | ||
txt_ids=txt_ids, | ||
vec=clip_embeddings, | ||
timesteps=timesteps, | ||
step_callback=self._build_step_callback(context), | ||
guidance=self.guidance, | ||
inpaint_extension=inpaint_extension, | ||
) | ||
|
||
x = unpack(x.float(), self.height, self.width) | ||
return x | ||
|
||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: | ||
"""Prepare the inpaint mask. | ||
- Loads the mask | ||
- Resizes if necessary | ||
- Casts to same device/dtype as latents | ||
- Expands mask to the same shape as latents so that they line up after 'packing' | ||
Args: | ||
context (InvocationContext): The invocation context, for loading the inpaint mask. | ||
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape, | ||
device, and dtype for the inpaint mask. | ||
Returns: | ||
torch.Tensor | None: Inpaint mask. | ||
""" | ||
if self.denoise_mask is None: | ||
return None | ||
|
||
mask = context.tensors.load(self.denoise_mask.mask_name) | ||
|
||
_, _, latent_height, latent_width = latents.shape | ||
mask = tv_resize( | ||
img=mask, | ||
size=[latent_height, latent_width], | ||
interpolation=tv_transforms.InterpolationMode.BILINEAR, | ||
antialias=False, | ||
) | ||
|
||
mask = mask.to(device=latents.device, dtype=latents.dtype) | ||
|
||
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with | ||
# `latents`. | ||
return mask.expand_as(latents) | ||
|
||
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]: | ||
def step_callback() -> None: | ||
if context.util.is_canceled(): | ||
raise CanceledException | ||
|
||
# TODO: Make this look like the image before re-enabling | ||
# latent_image = unpack(img.float(), self.height, self.width) | ||
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions | ||
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128] | ||
|
||
# # Create a new tensor of the required shape [255, 255, 3] | ||
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format | ||
|
||
# # Convert to a NumPy array and then to a PIL Image | ||
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8)) | ||
|
||
# (width, height) = image.size | ||
# width *= 8 | ||
# height *= 8 | ||
|
||
# dataURL = image_to_dataURL(image, image_format="JPEG") | ||
|
||
# # TODO: move this whole function to invocation context to properly reference these variables | ||
# context._services.events.emit_invocation_denoise_progress( | ||
# context._data.queue_item, | ||
# context._data.invocation, | ||
# state, | ||
# ProgressImage(dataURL=dataURL, width=width, height=height), | ||
# ) | ||
|
||
return step_callback |
Oops, something went wrong.