Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular backend - inpaint #6643

Merged
merged 13 commits into from
Jul 29, 2024
1 change: 1 addition & 0 deletions invokeai/app/invocations/create_gradient_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:

# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
StAlKeR7779 marked this conversation as resolved.
Show resolved Hide resolved

threshold = 1 - self.minimum_denoise

Expand Down
42 changes: 27 additions & 15 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
Expand All @@ -58,6 +58,8 @@
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
Expand Down Expand Up @@ -672,7 +674,7 @@ def prep_inpaint_mask(
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)

return 1 - mask, masked_latents, self.denoise_mask.gradient
return mask, masked_latents, self.denoise_mask.gradient

@staticmethod
def prepare_noise_and_latents(
Expand Down Expand Up @@ -730,10 +732,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
dtype = TorchDevice.choose_torch_dtype()

seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

_, _, latent_height, latent_width = latents.shape

conditioning_data = self.get_conditioning_data(
Expand Down Expand Up @@ -766,6 +764,27 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
denoising_end=self.denoising_end,
)

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

### inpaint
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
if unet_config.variant == ModelVariantType.Inpaint:
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
StAlKeR7779 marked this conversation as resolved.
Show resolved Hide resolved
elif mask is not None:
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))

# Initialize context for modular denoise
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
Expand All @@ -781,15 +800,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
scheduler=scheduler,
)

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

Expand Down Expand Up @@ -820,6 +830,8 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
if mask is not None:
mask = 1 - mask
StAlKeR7779 marked this conversation as resolved.
Show resolved Hide resolved

# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
Expand Down
116 changes: 116 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/inpaint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import einops
import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class InpaintExt(ExtensionBase):
    """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
    models.

    The extension hooks into the denoise loop: before every step it overwrites the region that must be preserved
    with a re-noised copy of the original latents, it applies the same blend to the step output, and after the
    loop it restores the preserved region from the original latents.
    """

    def __init__(
        self,
        mask: torch.Tensor,
        is_gradient_mask: bool,
    ):
        """Initialize InpaintExt.

        Args:
            mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
                inpainted.
            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
                1.
        """
        super().__init__()
        self._mask = mask
        self._is_gradient_mask = is_gradient_mask

        # Noise used to re-noise the preserved part of the image at each step.
        # If noise is provided through the denoise context, that noise is used;
        # otherwise it is generated from the seed (see `init_tensors`).
        self._noise: Optional[torch.Tensor] = None

    @staticmethod
    def _is_normal_model(unet: UNet2DConditionModel) -> bool:
        """Checks if the provided UNet belongs to a regular model.

        The `in_channels` of a UNet vary depending on model type:
        - normal - 4
        - depth - 5
        - inpaint - 9
        """
        return unet.conv_in.in_channels == 4

    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Blend `latents` with the original latents noised to timestep `t`, according to the mask.

        Args:
            ctx: The denoise context; its scheduler and `inputs.orig_latents` are read.
            latents: The latents to blend. Its first dimension is used as the batch size.
            t: The timestep at which the original latents are noised.

        Returns:
            The blended latents. For a binary mask this is a linear interpolation from `latents` (mask == 0) to
            the noised original latents (mask == 1). For a gradient mask, `latents` is kept wherever the mask is
            below `1 - t / num_train_timesteps` and the noised original latents are used elsewhere.
        """
        # NOTE(review): `t.item()` below and the second `einops.repeat` appear to assume a batch size of 1 -
        # confirm before calling this with larger batches.
        batch_size = latents.size(0)
        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, "-> batch", batch=batch_size)
        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
        if self._is_gradient_mask:
            # The share of the mask that is still open for denoising grows as the timestep decreases.
            threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
            mask_bool = mask < 1 - threshold
            masked_input = torch.where(mask_bool, latents, mask_latents)
        else:
            masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
        return masked_input

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext):
        """Validate the UNet type, move the mask to the latents' device/dtype and prepare the noise.

        Raises:
            ValueError: If the UNet's input convolution does not have 4 input channels.
        """
        if not self._is_normal_model(ctx.unet):
            raise ValueError("InpaintExt should be used only on normal models!")

        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

        self._noise = ctx.inputs.noise
        # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
        # We still need noise for inpainting, so we generate it from the seed here.
        if self._noise is None:
            # Generated on the CPU in float32 with a seeded generator, then moved to the latents' device/dtype.
            self._noise = torch.randn(
                ctx.latents.shape,
                dtype=torch.float32,
                device="cpu",
                generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)

    # TODO: order value
    @callback(ExtensionCallbackType.PRE_STEP, order=-100)
    def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
        """Apply the mask to the current latents, using the current timestep, before each step."""
        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)

    # TODO: order value
    # TODO: redo this with preview events rewrite
    @callback(ExtensionCallbackType.POST_STEP, order=-100)
    def apply_mask_to_step_output(self, ctx: DenoiseContext):
        """Apply the mask, at the scheduler's last timestep, to the step output's predicted sample.

        The attribute that is masked is `denoised` if present, else `pred_original_sample` if present; when
        neither exists, `pred_original_sample` is created from the masked `prev_sample`.
        """
        timestep = ctx.scheduler.timesteps[-1]
        if hasattr(ctx.step_output, "denoised"):
            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
        elif hasattr(ctx.step_output, "pred_original_sample"):
            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
        else:
            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)

    # TODO: should here be used order?
    # restore unmasked part after the last step is completed
    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
    def restore_unmasked(self, ctx: DenoiseContext):
        """Put the original latents back into the preserved region once denoising has finished."""
        if self._is_gradient_mask:
            # Only positions where the mask is exactly 1 (or more) are restored from the original latents.
            ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
        else:
            ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
89 changes: 89 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/inpaint_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class InpaintModelExt(ExtensionBase):
    """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
    models.

    Before every UNet call the extension appends the (inverted) mask and the masked latents to the UNet sample as
    extra channels, and after the denoise loop it restores the preserved region from the original latents.
    """

    def __init__(
        self,
        mask: Optional[torch.Tensor],
        masked_latents: Optional[torch.Tensor],
        is_gradient_mask: bool,
    ):
        """Initialize InpaintModelExt.

        Args:
            mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
                inpainted.
            masked_latents (Optional[torch.Tensor]): Latents of the initial image with the area to be inpainted
                blacked out. Must be provided whenever `mask` is provided.
                NOTE(review): this was previously documented as shape (1, 1, latent_height, latent_width), but the
                zero-filled default created in `init_tensors` has the same channel count as the latents, so
                (1, latent_channels, latent_height, latent_width) looks intended - confirm.
            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
                1.

        Raises:
            ValueError: If `mask` is provided without `masked_latents`.
        """
        super().__init__()
        if mask is not None and masked_latents is None:
            raise ValueError("Source image required for inpaint mask when inpaint model used!")

        # Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
        self._mask = None
        if mask is not None:
            self._mask = 1 - mask
        self._masked_latents = masked_latents
        self._is_gradient_mask = is_gradient_mask

    @staticmethod
    def _is_inpaint_model(unet: UNet2DConditionModel) -> bool:
        """Checks if the provided UNet belongs to an inpaint model.

        The `in_channels` of a UNet vary depending on model type:
        - normal - 4
        - depth - 5
        - inpaint - 9
        """
        return unet.conv_in.in_channels == 9

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext):
        """Validate the UNet type and prepare the mask and masked latents on the latents' device/dtype.

        When no mask was given, an all-ones mask (in the inverted convention: inpaint everywhere) is used. When no
        masked latents were given, zeros are used.

        Raises:
            ValueError: If the UNet's input convolution does not have 9 input channels.
        """
        if not self._is_inpaint_model(ctx.unet):
            raise ValueError("InpaintModelExt should be used only on inpaint models!")

        if self._mask is None:
            self._mask = torch.ones_like(ctx.latents[:1, :1])
        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

        if self._masked_latents is None:
            self._masked_latents = torch.zeros_like(ctx.latents[:1])
        self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

    # TODO: any ideas about order value?
    # do last so that other extensions works with normal latents
    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
    def append_inpaint_layers(self, ctx: DenoiseContext):
        """Append the mask and the masked latents to the UNet sample along the channel dimension."""
        batch_size = ctx.unet_kwargs.sample.shape[0]
        # Repeat the stored tensors along the batch dimension to match the sample.
        b_mask = torch.cat([self._mask] * batch_size)
        b_masked_latents = torch.cat([self._masked_latents] * batch_size)
        ctx.unet_kwargs.sample = torch.cat(
            [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
            dim=1,
        )

    # TODO: should here be used order?
    # restore unmasked part as inpaint model can change unmasked part slightly
    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
    def restore_unmasked(self, ctx: DenoiseContext):
        """Put the original latents back into the preserved region once denoising has finished."""
        # `self._mask` is stored inverted here: 0 means keep the original, 1 means inpainted.
        if self._is_gradient_mask:
            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
        else:
            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)