diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 560bc9003c6..96fbb2ca24b 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -901,7 +901,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
                 # ext: freeu, seamless, ip adapter, lora
                 ext_manager.patch_unet(unet, cached_weights),
             ):
-                sd_backend = StableDiffusionBackend(unet, scheduler)
+                sd_backend = StableDiffusionBackend()
                 denoise_ctx.unet = unet
                 result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py
index 9060d549776..7642c45c3ad 100644
--- a/invokeai/backend/stable_diffusion/denoise_context.py
+++ b/invokeai/backend/stable_diffusion/denoise_context.py
@@ -96,7 +96,8 @@ class DenoiseContext:
     timestep: Optional[torch.Tensor] = None
 
     # Arguments which will be passed to UNet model.
-    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
+    # Available in `PRE_UNET_FORWARD`/`POST_UNET_FORWARD` callbacks
+    # and in `UNET_FORWARD` override, otherwise will be None.
     unet_kwargs: Optional[UNetKwargs] = None
 
     # SchedulerOutput class returned from step function(normally, generated by scheduler).
@@ -109,7 +110,8 @@ class DenoiseContext:
     latent_model_input: Optional[torch.Tensor] = None
 
     # [TMP] Defines on which conditionings current unet call will be runned.
-    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
+    # Available in `PRE_UNET_FORWARD`/`POST_UNET_FORWARD` callbacks
+    # and in `UNET_FORWARD` override, otherwise will be None.
     conditioning_mode: Optional[ConditioningMode] = None
 
     # [TMP] Noise predictions from negative conditioning.
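Note: with this change StableDiffusionBackend no longer holds any model state; the unet and scheduler
travel on the DenoiseContext instead, which is why the comments above now point at the `*_FORWARD`
callbacks and override. A minimal sketch of the resulting wiring (the scheduler is assumed to already be
attached to denoise_ctx, as the ctx.scheduler usage later in the patch implies):

    sd_backend = StableDiffusionBackend()  # no unet/scheduler constructor args
    denoise_ctx.unet = unet                # models are carried by the context
    result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)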
diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py
index 4191db734f9..9df7d18d229 100644
--- a/invokeai/backend/stable_diffusion/diffusion_backend.py
+++ b/invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -1,25 +1,19 @@
 from __future__ import annotations
 
 import torch
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.schedulers.scheduling_utils import SchedulerOutput
 from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 
 
 class StableDiffusionBackend:
-    def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-    ):
-        self.unet = unet
-        self.scheduler = scheduler
+    def __init__(self):
         config = get_config()
         self._sequential_guidance = config.sequential_guidance
 
@@ -31,7 +25,7 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa
         if ctx.inputs.noise is not None:
             batch_size = ctx.latents.shape[0]
-            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
+            # latents = noise * ctx.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             ctx.latents = ctx.scheduler.add_noise(
                 ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
             )
@@ -49,7 +43,7 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa
             ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
 
             # ext: tiles? [override: step]
-            ctx.step_output = self.step(ctx, ext_manager)
+            ctx.step_output = ext_manager.run_override(ExtensionOverrideType.STEP, self.step, ctx, ext_manager)
 
             # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
             # ext: preview[post_step, priority=low]
@@ -77,7 +71,9 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
             ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
 
         # ext: override combine_noise_preds
-        ctx.noise_pred = self.combine_noise_preds(ctx)
+        ctx.noise_pred = ext_manager.run_override(
+            ExtensionOverrideType.COMBINE_NOISE_PREDS, self.combine_noise_preds, ctx, ext_manager
+        )
 
         # ext: cfg_rescale [modify_noise_prediction]
         # TODO: rename
@@ -94,17 +90,6 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
 
         return step_output
 
-    @staticmethod
-    def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
-        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
-        if isinstance(guidance_scale, list):
-            guidance_scale = guidance_scale[ctx.step_index]
-
-        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
-        # in slightly different outputs. It is suspected that this is caused by small precision differences.
-        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
-
     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
         if conditioning_mode == ConditioningMode.Both:
@@ -122,15 +107,13 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
         ctx.conditioning_mode = conditioning_mode
         ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
 
-        # ext: controlnet/ip/t2i [pre_unet]
-        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
+        # ext: controlnet/ip/t2i [pre_unet_forward]
+        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET_FORWARD, ctx)
 
-        # ext: inpaint [pre_unet, priority=low]
-        # or
-        # ext: inpaint [override: unet_forward]
-        noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
+        # ext: inpaint model/ic-light [override: unet_forward]
+        noise_pred = ext_manager.run_override(ExtensionOverrideType.UNET_FORWARD, self.unet_forward, ctx, ext_manager)
 
-        ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
+        ext_manager.run_callback(ExtensionCallbackType.POST_UNET_FORWARD, ctx)
 
         # clean up locals
         ctx.unet_kwargs = None
@@ -138,5 +121,17 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
 
         return noise_pred
 
-    def _unet_forward(self, **kwargs) -> torch.Tensor:
-        return self.unet(**kwargs).sample
+    # pass the extensions manager as an arg so that overrides can access it
+    def combine_noise_preds(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
+        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
+        if isinstance(guidance_scale, list):
+            guidance_scale = guidance_scale[ctx.step_index]
+
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+
+    # pass the extensions manager as an arg so that overrides can access it
+    def unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
+        return ctx.unet(**vars(ctx.unet_kwargs)).sample
diff --git a/invokeai/backend/stable_diffusion/extension_callback_type.py b/invokeai/backend/stable_diffusion/extension_callback_type.py
index e4c365007ba..8dfb1441568 100644
--- a/invokeai/backend/stable_diffusion/extension_callback_type.py
+++ b/invokeai/backend/stable_diffusion/extension_callback_type.py
@@ -7,6 +7,6 @@ class ExtensionCallbackType(Enum):
     POST_DENOISE_LOOP = "post_denoise_loop"
     PRE_STEP = "pre_step"
     POST_STEP = "post_step"
-    PRE_UNET = "pre_unet"
-    POST_UNET = "post_unet"
+    PRE_UNET_FORWARD = "pre_unet_forward"
+    POST_UNET_FORWARD = "post_unet_forward"
     POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
diff --git a/invokeai/backend/stable_diffusion/extension_override_type.py b/invokeai/backend/stable_diffusion/extension_override_type.py
new file mode 100644
index 00000000000..9256a736fd4
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extension_override_type.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class ExtensionOverrideType(Enum):
+    STEP = "step"
+    UNET_FORWARD = "unet_forward"
+    COMBINE_NOISE_PREDS = "combine_noise_preds"
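Note: the three members of ExtensionOverrideType above match the three ext_manager.run_override(...)
call sites in diffusion_backend.py. Every override follows the same calling convention: it receives the
original function as its first argument and may adjust the context first, delegate, or replace the stage
entirely. A generic sketch with invented names:

    def my_override(orig_func, ctx, ext_manager):
        # ...adjust ctx before the default stage runs...
        result = orig_func(ctx, ext_manager)  # delegate to the wrapped stage
        # ...post-process the result if needed...
        return result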
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 820d5d32a37..2667e7344fd 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -10,6 +10,7 @@
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
 
 
 @dataclass
@@ -35,22 +36,54 @@ def _decorator(function):
     return _decorator
 
 
+@dataclass
+class OverrideMetadata:
+    override_type: ExtensionOverrideType
+
+
+@dataclass
+class OverrideFunctionWithMetadata:
+    metadata: OverrideMetadata
+    function: Callable[..., Any]
+
+
+def override(override_type: ExtensionOverrideType):
+    def _decorator(function):
+        function._ext_metadata = OverrideMetadata(
+            override_type=override_type,
+        )
+        return function
+
+    return _decorator
+
+
 class ExtensionBase:
     def __init__(self):
         self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
+        self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}
 
         # Register all of the callback methods for this instance.
         for func_name in dir(self):
             func = getattr(self, func_name)
             metadata = getattr(func, "_ext_metadata", None)
-            if metadata is not None and isinstance(metadata, CallbackMetadata):
-                if metadata.callback_type not in self._callbacks:
-                    self._callbacks[metadata.callback_type] = []
-                self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
+            if metadata is not None:
+                if isinstance(metadata, CallbackMetadata):
+                    if metadata.callback_type not in self._callbacks:
+                        self._callbacks[metadata.callback_type] = []
+                    self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
+                elif isinstance(metadata, OverrideMetadata):
+                    if metadata.override_type in self._overrides:
+                        raise RuntimeError(
+                            f"Override {metadata.override_type} defined multiple times in {type(self).__qualname__}"
+                        )
+                    self._overrides[metadata.override_type] = OverrideFunctionWithMetadata(metadata, func)
 
     def get_callbacks(self):
         return self._callbacks
 
+    def get_overrides(self):
+        return self._overrides
+
     @contextmanager
     def patch_extension(self, ctx: DenoiseContext):
         yield None
diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py
index a48a681af3f..4b8b748a1ef 100644
--- a/invokeai/backend/stable_diffusion/extensions/controlnet.py
+++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py
@@ -68,8 +68,8 @@ def resize_image(self, ctx: DenoiseContext):
             resize_mode=self._resize_mode,
         )
 
-    @callback(ExtensionCallbackType.PRE_UNET)
-    def pre_unet_step(self, ctx: DenoiseContext):
+    @callback(ExtensionCallbackType.PRE_UNET_FORWARD)
+    def pre_unet_forward(self, ctx: DenoiseContext):
         # skip if model not active in current step
         total_steps = len(ctx.inputs.timesteps)
         first_step = math.floor(self._begin_step_percent * total_steps)
diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
index 6ee8ef6311c..cfe44f8125f 100644
--- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
+++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
@@ -1,15 +1,17 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
-from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback, override
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 
 
 class InpaintModelExt(ExtensionBase):
@@ -68,9 +70,8 @@ def init_tensors(self, ctx: DenoiseContext):
             self._masked_latents = torch.zeros_like(ctx.latents[:1])
         self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-    # Do last so that other extensions works with normal latents
-    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
-    def append_inpaint_layers(self, ctx: DenoiseContext):
+    @override(ExtensionOverrideType.UNET_FORWARD)
+    def append_inpaint_layers(self, orig_func: Callable[..., Any], ctx: DenoiseContext, ext_manager: ExtensionsManager):
         batch_size = ctx.unet_kwargs.sample.shape[0]
         b_mask = torch.cat([self._mask] * batch_size)
         b_masked_latents = torch.cat([self._masked_latents] * batch_size)
@@ -78,6 +79,7 @@ def append_inpaint_layers(self, ctx: DenoiseContext):
             [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
             dim=1,
         )
+        return orig_func(ctx, ext_manager)
 
     # Restore unmasked part as inpaint model can change unmasked part slightly
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
diff --git a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
index 5c290ea4e79..c7d1fc40646 100644
--- a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
+++ b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
@@ -96,8 +96,8 @@ def _run_model(
 
         return model(t2i_image)
 
-    @callback(ExtensionCallbackType.PRE_UNET)
-    def pre_unet_step(self, ctx: DenoiseContext):
+    @callback(ExtensionCallbackType.PRE_UNET_FORWARD)
+    def pre_unet_forward(self, ctx: DenoiseContext):
         # skip if model not active in current step
         total_steps = len(ctx.inputs.timesteps)
         first_step = math.floor(self._begin_step_percent * total_steps)
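Note: a hypothetical extension (names invented, not part of this patch) showing how the override
decorator from extensions/base.py is used, mirroring InpaintModelExt above:

    class ClampGuidanceExt(ExtensionBase):
        # Claim the COMBINE_NOISE_PREDS stage; orig_func is the backend's default.
        @override(ExtensionOverrideType.COMBINE_NOISE_PREDS)
        def combine_noise_preds(self, orig_func, ctx, ext_manager):
            noise_pred = orig_func(ctx, ext_manager)  # standard CFG combine
            return noise_pred.clamp(-10.0, 10.0)      # toy post-processing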
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index c8d585406a8..b9389e83bea 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -11,7 +11,12 @@
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
-    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
+    from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
+    from invokeai.backend.stable_diffusion.extensions.base import (
+        CallbackFunctionWithMetadata,
+        ExtensionBase,
+        OverrideFunctionWithMetadata,
+    )
 
 
 class ExtensionsManager:
@@ -21,11 +26,19 @@ def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
 
         # A list of extensions in the order that they were added to the ExtensionsManager.
         self._extensions: List[ExtensionBase] = []
         self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
+        self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}
 
     def add_extension(self, extension: ExtensionBase):
         self._extensions.append(extension)
         self._regenerate_ordered_callbacks()
 
+        for override_type, override in extension.get_overrides().items():
+            if override_type in self._overrides:
+                raise RuntimeError(
+                    f"Override {override_type} already defined by {self._overrides[override_type].function.__qualname__}"
+                )
+            self._overrides[override_type] = override
+
     def _regenerate_ordered_callbacks(self):
         """Regenerates self._ordered_callbacks.
 
         Intended to be called each time a new extension is added."""
         self._ordered_callbacks = {}
@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext
         for cb in callbacks:
             cb.function(ctx)
 
+    def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs):
+        if self._is_canceled and self._is_canceled():
+            raise CanceledException
+
+        override = self._overrides.get(override_type, None)
+        if override is not None:
+            return override.function(orig_function, *args, **kwargs)
+        else:
+            return orig_function(*args, **kwargs)
+
     @contextmanager
     def patch_extensions(self, ctx: DenoiseContext):
         if self._is_canceled and self._is_canceled():
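Note: dispatch is deliberately exclusive. At most one extension may claim a given ExtensionOverrideType;
a second claim raises the RuntimeError in add_extension above, and any stage with no registered override
falls through to orig_function, so pipelines without override extensions behave exactly as before. A
sketch (the second extension is hypothetical):

    ext_manager.add_extension(InpaintModelExt(...))  # claims UNET_FORWARD
    ext_manager.add_extension(OtherInpaintExt(...))  # would raise RuntimeError (override already defined)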