From 9e7b470189e3f8f9cee7baace4ccb37be9f8d61e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 20:45:55 +0300 Subject: [PATCH 01/12] Handle inpaint models --- invokeai/app/invocations/denoise_latents.py | 7 ++ .../extensions/inpaint_model.py | 66 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 invokeai/backend/stable_diffusion/extensions/inpaint_model.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index ccacc3303cf..1f28252408a 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP @@ -790,6 +791,12 @@ def step_callback(state: PipelineIntermediateState) -> None: ext_manager.add_extension(PreviewExt(step_callback)) + ### inpaint + # TODO: add inpainting on normal model + mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents) + if unet_config.variant == "inpaint": # ModelVariantType.Inpaint: + ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask)) + # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py new file mode 100644 index 00000000000..190e0fa9316 --- /dev/null +++ 
b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class InpaintModelExt(ExtensionBase): + def __init__( + self, + mask: Optional[torch.Tensor], + masked_latents: Optional[torch.Tensor], + is_gradient_mask: bool, + ): + super().__init__() + self.mask = mask + self.masked_latents = masked_latents + self.is_gradient_mask = is_gradient_mask + + @staticmethod + def _is_inpaint_model(unet: UNet2DConditionModel): + return unet.conv_in.in_channels == 9 + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def init_tensors(self, ctx: DenoiseContext): + if not self._is_inpaint_model(ctx.unet): + raise Exception("InpaintModelExt should be used only on inpaint model!") + + if self.mask is None: + self.mask = torch.ones_like(ctx.latents[:1, :1]) + self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + + if self.masked_latents is None: + self.masked_latents = torch.zeros_like(ctx.latents[:1]) + self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + + # TODO: any ideas about order value? + # do last so that other extensions works with normal latents + @callback(ExtensionCallbackType.PRE_UNET, order=1000) + def append_inpaint_layers(self, ctx: DenoiseContext): + batch_size = ctx.unet_kwargs.sample.shape[0] + b_mask = torch.cat([self.mask] * batch_size) + b_masked_latents = torch.cat([self.masked_latents] * batch_size) + ctx.unet_kwargs.sample = torch.cat( + [ctx.unet_kwargs.sample, b_mask, b_masked_latents], + dim=1, + ) + + # TODO: should here be used order? 
+ # restore unmasked part as inpaint model can change unmasked part slightly + @callback(ExtensionCallbackType.POST_DENOISE_LOOP) + def restore_unmasked(self, ctx: DenoiseContext): + if self.mask is None: + return + + if self.is_gradient_mask: + ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents) + else: + ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask) From 58f3072b9154f34df1cb6c4a26de3d59882ecd5c Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 22:17:29 +0300 Subject: [PATCH 02/12] Handle inpainting on normal models --- invokeai/app/invocations/denoise_latents.py | 8 +- .../stable_diffusion/extensions/inpaint.py | 91 +++++++++++++++++++ .../extensions/inpaint_model.py | 2 +- 3 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/inpaint.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 1f28252408a..3a9e0291af7 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -37,7 +37,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw -from invokeai.backend.model_manager import BaseModelType +from invokeai.backend.model_manager import BaseModelType, ModelVariantType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs @@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from 
invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -792,10 +793,11 @@ def step_callback(state: PipelineIntermediateState) -> None: ext_manager.add_extension(PreviewExt(step_callback)) ### inpaint - # TODO: add inpainting on normal model mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents) - if unet_config.variant == "inpaint": # ModelVariantType.Inpaint: + if unet_config.variant == ModelVariantType.Inpaint: ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask)) + elif mask is not None: + ext_manager.add_extension(InpaintExt(mask, is_gradient_mask)) # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py new file mode 100644 index 00000000000..5ef81f2c03c --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import einops +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class InpaintExt(ExtensionBase): + def __init__( + self, + mask: torch.Tensor, + is_gradient_mask: bool, + ): + super().__init__() + self.mask = mask + self.is_gradient_mask = is_gradient_mask + + @staticmethod + def _is_normal_model(unet: UNet2DConditionModel): + return unet.conv_in.in_channels == 4 + + def 
_apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + batch_size = latents.size(0) + mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size) + if t.dim() == 0: + # some schedulers expect t to be one-dimensional. + # TODO: file diffusers bug about inconsistency? + t = einops.repeat(t, "-> batch", batch=batch_size) + # Noise shouldn't be re-randomized between steps here. The multistep schedulers + # get very confused about what is happening from step to step when we do that. + mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self.noise, t) + # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? + # mask_latents = self.scheduler.scale_model_input(mask_latents, t) + mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) + if self.is_gradient_mask: + threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps + mask_bool = mask > threshhold # I don't know when mask got inverted, but it did + masked_input = torch.where(mask_bool, latents, mask_latents) + else: + masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) + return masked_input + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def init_tensors(self, ctx: DenoiseContext): + if not self._is_normal_model(ctx.unet): + raise Exception("InpaintExt should be used only on normal models!") + + self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + + self.noise = ctx.inputs.noise + if self.noise is None: + self.noise = torch.randn( + ctx.latents.shape, + dtype=torch.float32, + device="cpu", + generator=torch.Generator(device="cpu").manual_seed(ctx.seed), + ).to(device=ctx.latents.device, dtype=ctx.latents.dtype) + + # TODO: order value + @callback(ExtensionCallbackType.PRE_STEP, order=-100) + def apply_mask_to_initial_latents(self, ctx: DenoiseContext): + 
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep) + + # TODO: order value + # TODO: redo this with preview events rewrite + @callback(ExtensionCallbackType.POST_STEP, order=-100) + def apply_mask_to_step_output(self, ctx: DenoiseContext): + timestep = ctx.scheduler.timesteps[-1] + if hasattr(ctx.step_output, "denoised"): + ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep) + elif hasattr(ctx.step_output, "pred_original_sample"): + ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep) + else: + ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep) + + # TODO: should here be used order? + # restore unmasked part after the last step is completed + @callback(ExtensionCallbackType.POST_DENOISE_LOOP) + def restore_unmasked(self, ctx: DenoiseContext): + if self.is_gradient_mask: + ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents) + else: + ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index 190e0fa9316..b1cf8fa476e 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -31,7 +31,7 @@ def _is_inpaint_model(unet: UNet2DConditionModel): @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def init_tensors(self, ctx: DenoiseContext): if not self._is_inpaint_model(ctx.unet): - raise Exception("InpaintModelExt should be used only on inpaint model!") + raise Exception("InpaintModelExt should be used only on inpaint models!") if self.mask is None: self.mask = torch.ones_like(ctx.latents[:1, :1]) From 5003e5d763671d99893b63f8f7bb8d5b0b1f9aa9 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 22 Jul 2024 23:47:39 +0300 Subject: [PATCH 03/12] Same 
changes as in other PRs, add check for running inpainting on inpaint model without source image Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/app/invocations/denoise_latents.py | 2 +- .../stable_diffusion/extensions/inpaint.py | 27 ++++++++------- .../extensions/inpaint_model.py | 34 +++++++++---------- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 3a9e0291af7..eb1ee44bdaa 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -718,7 +718,7 @@ def prepare_noise_and_latents( return seed, noise, latents def invoke(self, context: InvocationContext) -> LatentsOutput: - if os.environ.get("USE_MODULAR_DENOISE", False): + if os.environ.get("USE_MODULAR_DENOISE", True): return self._new_invoke(context) else: return self._old_invoke(context) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index 5ef81f2c03c..27ea0a4ed6a 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import einops import torch @@ -20,8 +20,9 @@ def __init__( is_gradient_mask: bool, ): super().__init__() - self.mask = mask - self.is_gradient_mask = is_gradient_mask + self._mask = mask + self._is_gradient_mask = is_gradient_mask + self._noise: Optional[torch.Tensor] = None @staticmethod def _is_normal_model(unet: UNet2DConditionModel): @@ -29,18 +30,18 @@ def _is_normal_model(unet: UNet2DConditionModel): def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: batch_size = latents.size(0) - mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size) + mask 
= einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size) if t.dim() == 0: # some schedulers expect t to be one-dimensional. # TODO: file diffusers bug about inconsistency? t = einops.repeat(t, "-> batch", batch=batch_size) # Noise shouldn't be re-randomized between steps here. The multistep schedulers # get very confused about what is happening from step to step when we do that. - mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self.noise, t) + mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t) # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? # mask_latents = self.scheduler.scale_model_input(mask_latents, t) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) - if self.is_gradient_mask: + if self._is_gradient_mask: threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps mask_bool = mask > threshhold # I don't know when mask got inverted, but it did masked_input = torch.where(mask_bool, latents, mask_latents) @@ -53,11 +54,11 @@ def init_tensors(self, ctx: DenoiseContext): if not self._is_normal_model(ctx.unet): raise Exception("InpaintExt should be used only on normal models!") - self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) - self.noise = ctx.inputs.noise - if self.noise is None: - self.noise = torch.randn( + self._noise = ctx.inputs.noise + if self._noise is None: + self._noise = torch.randn( ctx.latents.shape, dtype=torch.float32, device="cpu", @@ -85,7 +86,7 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext): # restore unmasked part after the last step is completed @callback(ExtensionCallbackType.POST_DENOISE_LOOP) def restore_unmasked(self, ctx: DenoiseContext): - if self.is_gradient_mask: - ctx.latents = torch.where(self.mask > 0, ctx.latents, 
ctx.inputs.orig_latents) + if self._is_gradient_mask: + ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents) else: - ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask) + ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index b1cf8fa476e..9be259408f1 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -20,9 +20,12 @@ def __init__( is_gradient_mask: bool, ): super().__init__() - self.mask = mask - self.masked_latents = masked_latents - self.is_gradient_mask = is_gradient_mask + if mask is not None and masked_latents is None: + raise ValueError("Source image required for inpaint mask when inpaint model used!") + + self._mask = mask + self._masked_latents = masked_latents + self._is_gradient_mask = is_gradient_mask @staticmethod def _is_inpaint_model(unet: UNet2DConditionModel): @@ -33,21 +36,21 @@ def init_tensors(self, ctx: DenoiseContext): if not self._is_inpaint_model(ctx.unet): raise Exception("InpaintModelExt should be used only on inpaint models!") - if self.mask is None: - self.mask = torch.ones_like(ctx.latents[:1, :1]) - self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + if self._mask is None: + self._mask = torch.ones_like(ctx.latents[:1, :1]) + self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) - if self.masked_latents is None: - self.masked_latents = torch.zeros_like(ctx.latents[:1]) - self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + if self._masked_latents is None: + self._masked_latents = torch.zeros_like(ctx.latents[:1]) + self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) # TODO: any ideas about 
order value? # do last so that other extensions works with normal latents @callback(ExtensionCallbackType.PRE_UNET, order=1000) def append_inpaint_layers(self, ctx: DenoiseContext): batch_size = ctx.unet_kwargs.sample.shape[0] - b_mask = torch.cat([self.mask] * batch_size) - b_masked_latents = torch.cat([self.masked_latents] * batch_size) + b_mask = torch.cat([self._mask] * batch_size) + b_masked_latents = torch.cat([self._masked_latents] * batch_size) ctx.unet_kwargs.sample = torch.cat( [ctx.unet_kwargs.sample, b_mask, b_masked_latents], dim=1, @@ -57,10 +60,7 @@ def append_inpaint_layers(self, ctx: DenoiseContext): # restore unmasked part as inpaint model can change unmasked part slightly @callback(ExtensionCallbackType.POST_DENOISE_LOOP) def restore_unmasked(self, ctx: DenoiseContext): - if self.mask is None: - return - - if self.is_gradient_mask: - ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents) + if self._is_gradient_mask: + ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents) else: - ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask) + ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask) From 87eb0183807b05fe4843aa6718e9a7e1e18ee320 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 22 Jul 2024 23:49:20 +0300 Subject: [PATCH 04/12] Revert debug change --- invokeai/app/invocations/denoise_latents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index eb1ee44bdaa..3a9e0291af7 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -718,7 +718,7 @@ def prepare_noise_and_latents( return seed, noise, latents def invoke(self, context: InvocationContext) -> LatentsOutput: - if os.environ.get("USE_MODULAR_DENOISE", True): + if os.environ.get("USE_MODULAR_DENOISE", False): return 
self._new_invoke(context) else: return self._old_invoke(context) From 9d1fcba415d29c7f3d29c55a8f9ba1c5f9274193 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 23 Jul 2024 23:29:28 +0300 Subject: [PATCH 05/12] Fix create gradient mask node output --- invokeai/app/invocations/create_gradient_mask.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index 089313463bf..3b0afec1979 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -93,6 +93,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput: # redistribute blur so that the original edges are 0 and blur outwards to 1 blur_tensor = (blur_tensor - 0.5) * 2 + blur_tensor[blur_tensor < 0] = 0.0 threshold = 1 - self.minimum_denoise From c323a760a5bed1c5ecf701458b126927e66cf7b5 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 23 Jul 2024 23:34:28 +0300 Subject: [PATCH 06/12] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/app/invocations/denoise_latents.py | 39 ++++++++++--------- .../stable_diffusion/extensions/inpaint.py | 30 ++++++++++++-- .../extensions/inpaint_model.py | 22 ++++++++++- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 3a9e0291af7..0d9293be021 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -732,10 +732,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput: dtype = TorchDevice.choose_torch_dtype() seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) - latents = latents.to(device=device, dtype=dtype) - if noise is not None: - noise = noise.to(device=device, dtype=dtype) - _, _, latent_height, latent_width = latents.shape 
conditioning_data = self.get_conditioning_data( @@ -768,21 +764,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput: denoising_end=self.denoising_end, ) - denoise_ctx = DenoiseContext( - inputs=DenoiseInputs( - orig_latents=latents, - timesteps=timesteps, - init_timestep=init_timestep, - noise=noise, - seed=seed, - scheduler_step_kwargs=scheduler_step_kwargs, - conditioning_data=conditioning_data, - attention_processor_cls=CustomAttnProcessor2_0, - ), - unet=None, - scheduler=scheduler, - ) - # get the unet's config so that we can pass the base to sd_step_callback() unet_config = context.models.get_config(self.unet.unet.key) @@ -799,6 +780,26 @@ def step_callback(state: PipelineIntermediateState) -> None: elif mask is not None: ext_manager.add_extension(InpaintExt(mask, is_gradient_mask)) + # Initialize context for modular denoise + latents = latents.to(device=device, dtype=dtype) + if noise is not None: + noise = noise.to(device=device, dtype=dtype) + + denoise_ctx = DenoiseContext( + inputs=DenoiseInputs( + orig_latents=latents, + timesteps=timesteps, + init_timestep=init_timestep, + noise=noise, + seed=seed, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + attention_processor_cls=CustomAttnProcessor2_0, + ), + unet=None, + scheduler=scheduler, + ) + # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index 27ea0a4ed6a..fa58958b473 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -14,18 +14,40 @@ class InpaintExt(ExtensionBase): + """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting + models. + """ def __init__( self, mask: torch.Tensor, is_gradient_mask: bool, ): + """Initialize InpaintExt. 
+ Args: + mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are + expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be + inpainted. + is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range + from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or + 1. + """ super().__init__() self._mask = mask self._is_gradient_mask = is_gradient_mask + + # Noise, which used to noisify unmasked part of image + # if noise provided to context, then it will be used + # if no noise provided, then noise will be generated based on seed self._noise: Optional[torch.Tensor] = None @staticmethod def _is_normal_model(unet: UNet2DConditionModel): + """ Checks if the provided UNet belongs to a regular model. + The `in_channels` of a UNet vary depending on model type: + - normal - 4 + - depth - 5 + - inpaint - 9 + """ return unet.conv_in.in_channels == 4 def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: @@ -42,8 +64,8 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso # mask_latents = self.scheduler.scale_model_input(mask_latents, t) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) if self._is_gradient_mask: - threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps - mask_bool = mask > threshhold # I don't know when mask got inverted, but it did + threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps + mask_bool = mask > threshold masked_input = torch.where(mask_bool, latents, mask_latents) else: masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) @@ -52,11 +74,13 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def 
init_tensors(self, ctx: DenoiseContext): if not self._is_normal_model(ctx.unet): - raise Exception("InpaintExt should be used only on normal models!") + raise ValueError("InpaintExt should be used only on normal models!") self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) self._noise = ctx.inputs.noise + # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner). + # We still need noise for inpainting, so we generate it from the seed here. if self._noise is None: self._noise = torch.randn( ctx.latents.shape, diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index 9be259408f1..b5a08a85a85 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -13,12 +13,26 @@ class InpaintModelExt(ExtensionBase): + """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting + models. + """ def __init__( self, mask: Optional[torch.Tensor], masked_latents: Optional[torch.Tensor], is_gradient_mask: bool, ): + """Initialize InpaintModelExt. + Args: + mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are + expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be + inpainted. + masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area. + If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width) + is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range + from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or + 1. 
+ """ super().__init__() if mask is not None and masked_latents is None: raise ValueError("Source image required for inpaint mask when inpaint model used!") @@ -29,12 +43,18 @@ def __init__( @staticmethod def _is_inpaint_model(unet: UNet2DConditionModel): + """ Checks if the provided UNet belongs to a regular model. + The `in_channels` of a UNet vary depending on model type: + - normal - 4 + - depth - 5 + - inpaint - 9 + """ return unet.conv_in.in_channels == 9 @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def init_tensors(self, ctx: DenoiseContext): if not self._is_inpaint_model(ctx.unet): - raise Exception("InpaintModelExt should be used only on inpaint models!") + raise ValueError("InpaintModelExt should be used only on inpaint models!") if self._mask is None: self._mask = torch.ones_like(ctx.latents[:1, :1]) From 19c00241c6d5bc58ea0058aa90ffb9a22b4f30c7 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 24 Jul 2024 00:59:13 +0300 Subject: [PATCH 07/12] Use non-inverted mask generally(except inpaint model handling) --- invokeai/app/invocations/denoise_latents.py | 4 +++- .../backend/stable_diffusion/extensions/inpaint.py | 10 +++++----- .../stable_diffusion/extensions/inpaint_model.py | 7 +++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 0d9293be021..b7a296a9b4c 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -674,7 +674,7 @@ def prep_inpaint_mask( else: masked_latents = torch.where(mask < 0.5, 0.0, latents) - return 1 - mask, masked_latents, self.denoise_mask.gradient + return mask, masked_latents, self.denoise_mask.gradient @staticmethod def prepare_noise_and_latents( @@ -830,6 +830,8 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput: seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) mask, masked_latents, 
gradient_mask = self.prep_inpaint_mask(context, latents) + if mask is not None: + mask = 1 - mask # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, # below. Investigate whether this is appropriate. diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index fa58958b473..6bf155b44fe 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -25,7 +25,7 @@ def __init__( """Initialize InpaintExt. Args: mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are - expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be + expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be inpainted. is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range from 0 to 1. 
If False, mask is interpreted as binary mask meaning that the mask values are either 0 or @@ -65,10 +65,10 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) if self._is_gradient_mask: threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps - mask_bool = mask > threshold + mask_bool = mask < 1 - threshold masked_input = torch.where(mask_bool, latents, mask_latents) else: - masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) + masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype)) return masked_input @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) @@ -111,6 +111,6 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext): @callback(ExtensionCallbackType.POST_DENOISE_LOOP) def restore_unmasked(self, ctx: DenoiseContext): if self._is_gradient_mask: - ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents) + ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents) else: - ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask) + ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index b5a08a85a85..e1cadb0a2e2 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -25,7 +25,7 @@ def __init__( """Initialize InpaintModelExt. Args: mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are - expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be + expected to be in the range [0, 1]. 
A value of 1 means that the corresponding 'pixel' should not be inpainted. masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area. If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width) @@ -37,7 +37,10 @@ def __init__( if mask is not None and masked_latents is None: raise ValueError("Source image required for inpaint mask when inpaint model used!") - self._mask = mask + # Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint + self._mask = None + if mask is not None: + self._mask = 1 - mask self._masked_latents = masked_latents self._is_gradient_mask = is_gradient_mask From 416d29fb839ecdae932e8f9a0907fd04fb5449ca Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 24 Jul 2024 01:17:28 +0300 Subject: [PATCH 08/12] Ruff format --- invokeai/backend/stable_diffusion/extensions/inpaint.py | 3 ++- invokeai/backend/stable_diffusion/extensions/inpaint_model.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index 6bf155b44fe..7bdd9238dfd 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -17,6 +17,7 @@ class InpaintExt(ExtensionBase): """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting models. """ + def __init__( self, mask: torch.Tensor, @@ -42,7 +43,7 @@ def __init__( @staticmethod def _is_normal_model(unet: UNet2DConditionModel): - """ Checks if the provided UNet belongs to a regular model. + """Checks if the provided UNet belongs to a regular model. 
The `in_channels` of a UNet vary depending on model type: - normal - 4 - depth - 5 diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index e1cadb0a2e2..4a89f8223f4 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -16,6 +16,7 @@ class InpaintModelExt(ExtensionBase): """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting models. """ + def __init__( self, mask: Optional[torch.Tensor], @@ -46,7 +47,7 @@ def __init__( @staticmethod def _is_inpaint_model(unet: UNet2DConditionModel): - """ Checks if the provided UNet belongs to a regular model. + """Checks if the provided UNet belongs to an inpainting model. The `in_channels` of a UNet vary depending on model type: - normal - 4 - depth - 5 From bd8890be113b8cb4dae8c8bf5cf4222b057365da Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 26 Jul 2024 19:24:46 +0300 Subject: [PATCH 09/12] Revert "Fix create gradient mask node output" This reverts commit 9d1fcba415d29c7f3d29c55a8f9ba1c5f9274193. 
--- invokeai/app/invocations/create_gradient_mask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index 3b0afec1979..089313463bf 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -93,7 +93,6 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput: # redistribute blur so that the original edges are 0 and blur outwards to 1 blur_tensor = (blur_tensor - 0.5) * 2 - blur_tensor[blur_tensor < 0] = 0.0 threshold = 1 - self.minimum_denoise From 5810cee6c96292bef76cc7be522995eecf013a28 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 26 Jul 2024 19:47:28 +0300 Subject: [PATCH 10/12] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/app/invocations/denoise_latents.py | 6 ++++++ invokeai/backend/stable_diffusion/extensions/inpaint.py | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index b7a296a9b4c..c502234e5ee 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -775,6 +775,10 @@ def step_callback(state: PipelineIntermediateState) -> None: ### inpaint mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents) + # NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we + # use the ModelVariantType config. During testing, there was a report of a user with models that had an + # incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be + # prevalent, we will have to revisit how we initialize the inpainting extensions. 
if unet_config.variant == ModelVariantType.Inpaint: ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask)) elif mask is not None: @@ -830,6 +834,8 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput: seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) + # At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint). + # We invert the mask here for compatibility with the old backend implementation. if mask is not None: mask = 1 - mask diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index 7bdd9238dfd..437e06df76c 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -75,7 +75,11 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def init_tensors(self, ctx: DenoiseContext): if not self._is_normal_model(ctx.unet): - raise ValueError("InpaintExt should be used only on normal models!") + raise ValueError( + "InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an " + "inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be " + "fixed by removing and re-adding the model (so that it gets re-probed)." 
+ ) self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) From ed0174fbc6fec3df7f770494626731840dce58c7 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 27 Jul 2024 13:18:28 +0300 Subject: [PATCH 11/12] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- invokeai/backend/stable_diffusion/extensions/inpaint.py | 7 +++---- .../backend/stable_diffusion/extensions/inpaint_model.py | 6 ++---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index 437e06df76c..00793591558 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -94,13 +94,13 @@ def init_tensors(self, ctx: DenoiseContext): generator=torch.Generator(device="cpu").manual_seed(ctx.seed), ).to(device=ctx.latents.device, dtype=ctx.latents.dtype) - # TODO: order value + # Use negative order to make extensions with default order work with patched latents @callback(ExtensionCallbackType.PRE_STEP, order=-100) def apply_mask_to_initial_latents(self, ctx: DenoiseContext): ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep) - # TODO: order value # TODO: redo this with preview events rewrite + # Use negative order to make extensions with default order work with patched latents @callback(ExtensionCallbackType.POST_STEP, order=-100) def apply_mask_to_step_output(self, ctx: DenoiseContext): timestep = ctx.scheduler.timesteps[-1] @@ -111,8 +111,7 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext): else: ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep) - # TODO: should here be used order? 
- # restore unmasked part after the last step is completed + # Restore unmasked part after the last step is completed @callback(ExtensionCallbackType.POST_DENOISE_LOOP) def restore_unmasked(self, ctx: DenoiseContext): if self._is_gradient_mask: diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index 4a89f8223f4..98ee66c458c 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -68,8 +68,7 @@ def init_tensors(self, ctx: DenoiseContext): self._masked_latents = torch.zeros_like(ctx.latents[:1]) self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) - # TODO: any ideas about order value? - # do last so that other extensions works with normal latents + # Use negative order to make extensions with default order work with patched latents @callback(ExtensionCallbackType.PRE_UNET, order=1000) def append_inpaint_layers(self, ctx: DenoiseContext): batch_size = ctx.unet_kwargs.sample.shape[0] @@ -80,8 +79,7 @@ def append_inpaint_layers(self, ctx: DenoiseContext): dim=1, ) - # TODO: should here be used order? 
- # restore unmasked part as inpaint model can change unmasked part slightly + # Restore unmasked part as inpaint model can change unmasked part slightly @callback(ExtensionCallbackType.POST_DENOISE_LOOP) def restore_unmasked(self, ctx: DenoiseContext): if self._is_gradient_mask: From 84d028898cf4cc59f7b15a43fa8746469534cb0e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 27 Jul 2024 13:20:58 +0300 Subject: [PATCH 12/12] Revert wrong comment copy --- invokeai/backend/stable_diffusion/extensions/inpaint_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index 98ee66c458c..6ee8ef6311c 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -68,7 +68,7 @@ def init_tensors(self, ctx: DenoiseContext): self._masked_latents = torch.zeros_like(ctx.latents[:1]) self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) - # Use negative order to make extensions with default order work with patched latents + # Do last so that other extensions work with normal latents @callback(ExtensionCallbackType.PRE_UNET, order=1000) def append_inpaint_layers(self, ctx: DenoiseContext): batch_size = ctx.unet_kwargs.sample.shape[0]