Modular backend - overrides #6692

Open · wants to merge 5 commits into main
2 changes: 1 addition & 1 deletion invokeai/app/invocations/denoise_latents.py
@@ -901,7 +901,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
             # ext: freeu, seamless, ip adapter, lora
             ext_manager.patch_unet(unet, cached_weights),
         ):
-            sd_backend = StableDiffusionBackend(unet, scheduler)
+            sd_backend = StableDiffusionBackend()
             denoise_ctx.unet = unet
             result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

6 changes: 4 additions & 2 deletions invokeai/backend/stable_diffusion/denoise_context.py
@@ -96,7 +96,8 @@ class DenoiseContext:
     timestep: Optional[torch.Tensor] = None

     # Arguments which will be passed to UNet model.
-    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
+    # Available in `PRE_UNET_FORWARD`/`POST_UNET_FORWARD` callbacks
+    # and in `UNET_FORWARD` override, otherwise will be None.
     unet_kwargs: Optional[UNetKwargs] = None

     # SchedulerOutput class returned from step function(normally, generated by scheduler).
@@ -109,7 +110,8 @@ class DenoiseContext:
     latent_model_input: Optional[torch.Tensor] = None

     # [TMP] Defines on which conditionings current unet call will be runned.
-    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
+    # Available in `PRE_UNET_FORWARD`/`POST_UNET_FORWARD` callbacks
+    # and in `UNET_FORWARD` override, otherwise will be None.
     conditioning_mode: Optional[ConditioningMode] = None

     # [TMP] Noise predictions from negative conditioning.
59 changes: 27 additions & 32 deletions invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -1,25 +1,19 @@
 from __future__ import annotations

 import torch
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.schedulers.scheduling_utils import SchedulerOutput
 from tqdm.auto import tqdm

 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


 class StableDiffusionBackend:
-    def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-    ):
-        self.unet = unet
-        self.scheduler = scheduler
+    def __init__(self):
         config = get_config()
         self._sequential_guidance = config.sequential_guidance

@@ -31,7 +25,7 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa

         if ctx.inputs.noise is not None:
             batch_size = ctx.latents.shape[0]
-            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
+            # latents = noise * ctx.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             ctx.latents = ctx.scheduler.add_noise(
                 ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
             )
@@ -49,7 +43,7 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa
             ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)

             # ext: tiles? [override: step]
-            ctx.step_output = self.step(ctx, ext_manager)
+            ctx.step_output = ext_manager.run_override(ExtensionOverrideType.STEP, self.step, ctx, ext_manager)

             # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
             # ext: preview[post_step, priority=low]
@@ -77,7 +71,9 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
         ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)

         # ext: override combine_noise_preds
-        ctx.noise_pred = self.combine_noise_preds(ctx)
+        ctx.noise_pred = ext_manager.run_override(
+            ExtensionOverrideType.COMBINE_NOISE_PREDS, self.combine_noise_preds, ctx, ext_manager
+        )

         # ext: cfg_rescale [modify_noise_prediction]
         # TODO: rename
@@ -94,17 +90,6 @@

         return step_output

-    @staticmethod
-    def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
-        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
-        if isinstance(guidance_scale, list):
-            guidance_scale = guidance_scale[ctx.step_index]
-
-        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
-        # in slightly different outputs. It is suspected that this is caused by small precision differences.
-        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
-
     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
         if conditioning_mode == ConditioningMode.Both:
@@ -122,21 +107,31 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
         ctx.conditioning_mode = conditioning_mode
         ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

-        # ext: controlnet/ip/t2i [pre_unet]
-        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
+        # ext: controlnet/ip/t2i [pre_unet_forward]
+        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET_FORWARD, ctx)

-        # ext: inpaint [pre_unet, priority=low]
-        # or
-        # ext: inpaint [override: unet_forward]
-        noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
+        # ext: inpaint model/ic-light [override: unet_forward]
+        noise_pred = ext_manager.run_override(ExtensionOverrideType.UNET_FORWARD, self.unet_forward, ctx, ext_manager)

-        ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
+        ext_manager.run_callback(ExtensionCallbackType.POST_UNET_FORWARD, ctx)

         # clean up locals
         ctx.unet_kwargs = None
         ctx.conditioning_mode = None

         return noise_pred

-    def _unet_forward(self, **kwargs) -> torch.Tensor:
-        return self.unet(**kwargs).sample
+    # pass extensions manager as arg to allow override access it
+    def combine_noise_preds(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
+        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
+        if isinstance(guidance_scale, list):
+            guidance_scale = guidance_scale[ctx.step_index]
+
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+
+    # pass extensions manager as arg to allow override access it

Collaborator: Why would we want this? Seems like it just opens the door for a bunch of messiness.

StAlKeR7779 (Contributor, Author, Aug 8, 2024): How else would tiled denoise be able to call the original step function or the callbacks?

Collaborator: Maybe this will be easier to discuss in the context of the tiled denoise PR? It seems to me that if we can avoid passing the ext_manager down to callbacks/overrides, that would keep things quite a bit simpler.
+    def unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
+        return ctx.unet(**vars(ctx.unet_kwargs)).sample
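
For readers following the thread above: a minimal sketch of how a tiled-denoise extension might use the STEP override, assuming the run_override contract introduced in this PR. TiledDenoiseExt and its degenerate one-tile "tiling" are hypothetical, for illustration only.

from typing import Any, Callable

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, override
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class TiledDenoiseExt(ExtensionBase):
    """Hypothetical extension illustrating why an override receives orig_func and ext_manager."""

    @override(ExtensionOverrideType.STEP)
    def tiled_step(self, orig_func: Callable[..., Any], ctx: DenoiseContext, ext_manager: ExtensionsManager):
        # Degenerate tiling for illustration: treat the full latent as a single tile.
        # A real implementation would slice ctx.latents into overlapping tiles, run
        # the original step (orig_func) on each tile, and blend the per-tile outputs.
        return orig_func(ctx, ext_manager)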
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extension_callback_type.py
@@ -7,6 +7,6 @@ class ExtensionCallbackType(Enum):
     POST_DENOISE_LOOP = "post_denoise_loop"
     PRE_STEP = "pre_step"
     POST_STEP = "post_step"
-    PRE_UNET = "pre_unet"
-    POST_UNET = "post_unet"
+    PRE_UNET_FORWARD = "pre_unet_forward"
+    POST_UNET_FORWARD = "post_unet_forward"
     POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
7 changes: 7 additions & 0 deletions invokeai/backend/stable_diffusion/extension_override_type.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class ExtensionOverrideType(Enum):
+    STEP = "step"
+    UNET_FORWARD = "unet_forward"
+    COMBINE_NOISE_PREDS = "combine_noise_preds"
43 changes: 38 additions & 5 deletions invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,14 +2,15 @@

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType


 @dataclass
@@ -35,22 +36,54 @@ def _decorator(function):
         return _decorator


+@dataclass
+class OverrideMetadata:
+    override_type: ExtensionOverrideType
+
+
+@dataclass
+class OverrideFunctionWithMetadata:
+    metadata: OverrideMetadata
+    function: Callable[..., Any]
+
+
+def override(override_type: ExtensionOverrideType):
+    def _decorator(function):
+        function._ext_metadata = OverrideMetadata(
+            override_type=override_type,
+        )
+        return function
+
+    return _decorator
+
+
 class ExtensionBase:
     def __init__(self):
         self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
+        self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}

Collaborator (on lines 62 to 63): Add docs explaining the difference between _overrides and _callbacks. Include guidance for developers on how to decide between using a callback vs. an override. In some cases both could work, so highlight the things that should be considered to decide between them. (A sketch appears after this file's diff.)

         # Register all of the callback methods for this instance.
         for func_name in dir(self):
             func = getattr(self, func_name)
             metadata = getattr(func, "_ext_metadata", None)
-            if metadata is not None and isinstance(metadata, CallbackMetadata):
-                if metadata.callback_type not in self._callbacks:
-                    self._callbacks[metadata.callback_type] = []
-                self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
+            if metadata is not None:
+                if isinstance(metadata, CallbackMetadata):
+                    if metadata.callback_type not in self._callbacks:
+                        self._callbacks[metadata.callback_type] = []
+                    self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
+                elif isinstance(metadata, OverrideMetadata):
+                    if metadata.override_type in self._overrides:
+                        raise RuntimeError(
+                            f"Override {metadata.override_type} defined multiple times in {type(self).__qualname__}"
+                        )
+                    self._overrides[metadata.override_type] = OverrideFunctionWithMetadata(metadata, func)

     def get_callbacks(self):
         return self._callbacks

+    def get_overrides(self):
+        return self._overrides
+
     @contextmanager
     def patch_extension(self, ctx: DenoiseContext):
         yield None
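On the docs request above, the practical distinction in this PR is: callbacks are additive observers (every callback registered for a type runs, in priority order), while an override replaces a single core operation and at most one extension may claim each override type. A minimal sketch of both registration styles, using a hypothetical ExampleExt:

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback, override


class ExampleExt(ExtensionBase):
    # Callback: observe or tweak the context at a fixed point in the loop;
    # composes with other extensions' callbacks for the same type.
    @callback(ExtensionCallbackType.PRE_UNET_FORWARD, order=0)
    def inspect_sample(self, ctx: DenoiseContext):
        print(ctx.unet_kwargs.sample.shape)

    # Override: replace the UNet forward itself; exclusive per override type,
    # and free to decide whether (and how) to call the original via orig_func.
    @override(ExtensionOverrideType.UNET_FORWARD)
    def custom_unet_forward(self, orig_func, ctx: DenoiseContext, ext_manager):
        return orig_func(ctx, ext_manager)

A reasonable rule of thumb: prefer a callback when the behavior should compose with other extensions; reach for an override only when the operation must be wholly replaced.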
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extensions/controlnet.py
@@ -68,8 +68,8 @@ def resize_image(self, ctx: DenoiseContext):
             resize_mode=self._resize_mode,
         )

-    @callback(ExtensionCallbackType.PRE_UNET)
-    def pre_unet_step(self, ctx: DenoiseContext):
+    @callback(ExtensionCallbackType.PRE_UNET_FORWARD)
+    def pre_unet_forward(self, ctx: DenoiseContext):
         # skip if model not active in current step
         total_steps = len(ctx.inputs.timesteps)
         first_step = math.floor(self._begin_step_percent * total_steps)
12 changes: 7 additions & 5 deletions invokeai/backend/stable_diffusion/extensions/inpaint_model.py
@@ -1,15 +1,17 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional

 import torch
 from diffusers import UNet2DConditionModel

 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
-from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback, override

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


 class InpaintModelExt(ExtensionBase):
@@ -68,16 +70,16 @@ def init_tensors(self, ctx: DenoiseContext):
             self._masked_latents = torch.zeros_like(ctx.latents[:1])
         self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

-    # Do last so that other extensions works with normal latents
-    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
-    def append_inpaint_layers(self, ctx: DenoiseContext):
+    @override(ExtensionOverrideType.UNET_FORWARD)
+    def append_inpaint_layers(self, orig_func: Callable[..., Any], ctx: DenoiseContext, ext_manager: ExtensionsManager):
         batch_size = ctx.unet_kwargs.sample.shape[0]
         b_mask = torch.cat([self._mask] * batch_size)
         b_masked_latents = torch.cat([self._masked_latents] * batch_size)
         ctx.unet_kwargs.sample = torch.cat(
             [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
             dim=1,
         )
+        return orig_func(ctx, ext_manager)

     # Restore unmasked part as inpaint model can change unmasked part slightly
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
@@ -96,8 +96,8 @@ def _run_model(

         return model(t2i_image)

-    @callback(ExtensionCallbackType.PRE_UNET)
-    def pre_unet_step(self, ctx: DenoiseContext):
+    @callback(ExtensionCallbackType.PRE_UNET_FORWARD)
+    def pre_unet_forward(self, ctx: DenoiseContext):
         # skip if model not active in current step
         total_steps = len(ctx.inputs.timesteps)
         first_step = math.floor(self._begin_step_percent * total_steps)
27 changes: 25 additions & 2 deletions invokeai/backend/stable_diffusion/extensions_manager.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel
@@ -11,7 +11,12 @@
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
-    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
+    from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
+    from invokeai.backend.stable_diffusion.extensions.base import (
+        CallbackFunctionWithMetadata,
+        ExtensionBase,
+        OverrideFunctionWithMetadata,
+    )


 class ExtensionsManager:
@@ -21,11 +26,19 @@ def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
         # A list of extensions in the order that they were added to the ExtensionsManager.
         self._extensions: List[ExtensionBase] = []
         self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
+        self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}

Collaborator: Let's add unit tests for the new ExtensionBase/ExtensionsManager functionality, given its core role. It should be straightforward; you can use the existing tests for reference. I think we'd roughly want tests for each of the following (a sketch follows this list):

  • When an override is registered, it gets called.
  • Calling an override type with no override registered behaves as expected.
  • When duplicate overrides are registered, a meaningful error is raised.

     def add_extension(self, extension: ExtensionBase):
         self._extensions.append(extension)
         self._regenerate_ordered_callbacks()

+        for override_type, override in extension.get_overrides().items():
+            if override_type in self._overrides:
+                raise RuntimeError(
+                    f"Override {override_type} already defined by {self._overrides[override_type].function.__qualname__}"
+                )
+            self._overrides[override_type] = override
+
     def _regenerate_ordered_callbacks(self):
         """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
         self._ordered_callbacks = {}
@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext
         for cb in callbacks:
             cb.function(ctx)

+    def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs):

Collaborator: It would be nice to have typed function signatures for each override type, given that the signatures are known and there aren't very many of them (instead of passing *args and **kwargs). (See the sketch after this thread.)

Collaborator: What's the reason for passing orig_function? If the orig_function needs to be called, it feels like those use cases could be solved with callbacks.

StAlKeR7779 (Contributor, Author, Aug 8, 2024): At least to allow the extension manager to run the original, non-overridden implementation. It's also simply more flexible: you don't have to reimplement the underlying logic if you only want to patch it slightly. Tiled denoise will also use the original step function on each tile.

Collaborator: Including it in the function signature implies that the function/override should handle it. For the case you're describing, I imagined that it would just look like this:

from ... import unet_forward

class AnExtension(ExtensionBase):
    @override(ExtensionOverrideType.UNET_FORWARD)
    def custom_unet_forward(self, ...):
        # Do some stuff...
        unet_forward(...)

What do you think?
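
One possible shape for the typed-signature suggestion above; TypedExtensionsManager and its wrapper methods are illustrative, not part of this PR:

from typing import Callable

import torch
from diffusers.schedulers.scheduling_utils import SchedulerOutput

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class TypedExtensionsManager(ExtensionsManager):
    """Hypothetical: one typed entry point per override type, instead of *args/**kwargs."""

    def run_step_override(
        self, orig_function: Callable[[DenoiseContext, ExtensionsManager], SchedulerOutput], ctx: DenoiseContext
    ) -> SchedulerOutput:
        return self.run_override(ExtensionOverrideType.STEP, orig_function, ctx, self)

    def run_unet_forward_override(
        self, orig_function: Callable[[DenoiseContext, ExtensionsManager], torch.Tensor], ctx: DenoiseContext
    ) -> torch.Tensor:
        return self.run_override(ExtensionOverrideType.UNET_FORWARD, orig_function, ctx, self)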

+        if self._is_canceled and self._is_canceled():
+            raise CanceledException
+
+        override = self._overrides.get(override_type, None)
+        if override is not None:
+            return override.function(orig_function, *args, **kwargs)
+        else:
+            return orig_function(*args, **kwargs)
+
     @contextmanager
     def patch_extensions(self, ctx: DenoiseContext):
         if self._is_canceled and self._is_canceled():