-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modular backend - overrides #6692
base: main
Are you sure you want to change the base?
Changes from all commits
694e802
920ea95
4d893c9
b6c0b4f
6159435
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from enum import Enum | ||
|
||
|
||
class ExtensionOverrideType(Enum): | ||
STEP = "step" | ||
UNET_FORWARD = "unet_forward" | ||
COMBINE_NOISE_PREDS = "combine_noise_preds" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,15 @@ | |
|
||
from contextlib import contextmanager | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional | ||
|
||
import torch | ||
from diffusers import UNet2DConditionModel | ||
|
||
if TYPE_CHECKING: | ||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType | ||
|
||
|
||
@dataclass | ||
|
@@ -35,22 +36,54 @@ def _decorator(function): | |
return _decorator | ||
|
||
|
||
@dataclass | ||
class OverrideMetadata: | ||
override_type: ExtensionOverrideType | ||
|
||
|
||
@dataclass | ||
class OverrideFunctionWithMetadata: | ||
metadata: OverrideMetadata | ||
function: Callable[..., Any] | ||
|
||
|
||
def override(override_type: ExtensionOverrideType): | ||
def _decorator(function): | ||
function._ext_metadata = OverrideMetadata( | ||
override_type=override_type, | ||
) | ||
return function | ||
|
||
return _decorator | ||
|
||
|
||
class ExtensionBase: | ||
def __init__(self): | ||
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} | ||
self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {} | ||
Comment on lines
62
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add docs explaining the difference between _overrides and _callbacks. Include guidance for developers for how they should decide between using a callback vs. an override. In some cases, both could work, so highlight the things that should be considered to decide between them. |
||
|
||
# Register all of the callback methods for this instance. | ||
for func_name in dir(self): | ||
func = getattr(self, func_name) | ||
metadata = getattr(func, "_ext_metadata", None) | ||
if metadata is not None and isinstance(metadata, CallbackMetadata): | ||
if metadata.callback_type not in self._callbacks: | ||
self._callbacks[metadata.callback_type] = [] | ||
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func)) | ||
if metadata is not None: | ||
if isinstance(metadata, CallbackMetadata): | ||
if metadata.callback_type not in self._callbacks: | ||
self._callbacks[metadata.callback_type] = [] | ||
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func)) | ||
elif isinstance(metadata, OverrideMetadata): | ||
if metadata.override_type in self._overrides: | ||
raise RuntimeError( | ||
f"Override {metadata.override_type} defined multiple times in {type(self).__qualname__}" | ||
) | ||
self._overrides[metadata.override_type] = OverrideFunctionWithMetadata(metadata, func) | ||
|
||
def get_callbacks(self): | ||
return self._callbacks | ||
|
||
def get_overrides(self): | ||
return self._overrides | ||
|
||
@contextmanager | ||
def patch_extension(self, ctx: DenoiseContext): | ||
yield None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from __future__ import annotations | ||
|
||
from contextlib import ExitStack, contextmanager | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional | ||
|
||
import torch | ||
from diffusers import UNet2DConditionModel | ||
|
@@ -11,7 +11,12 @@ | |
if TYPE_CHECKING: | ||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase | ||
from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType | ||
from invokeai.backend.stable_diffusion.extensions.base import ( | ||
CallbackFunctionWithMetadata, | ||
ExtensionBase, | ||
OverrideFunctionWithMetadata, | ||
) | ||
|
||
|
||
class ExtensionsManager: | ||
|
@@ -21,11 +26,19 @@ def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): | |
# A list of extensions in the order that they were added to the ExtensionsManager. | ||
self._extensions: List[ExtensionBase] = [] | ||
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} | ||
self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add unit tests for the new ExtensionBase/ExtensionsManager functionality given it's core role. It should be straightforward - you can use the existing tests for reference. I think we'd roughly want tests for each of the following:
|
||
|
||
def add_extension(self, extension: ExtensionBase): | ||
self._extensions.append(extension) | ||
self._regenerate_ordered_callbacks() | ||
|
||
for override_type, override in extension.get_overrides().items(): | ||
if override_type in self._overrides: | ||
raise RuntimeError( | ||
f"Override {override_type} already defined by {self._overrides[override_type].function.__qualname__}" | ||
) | ||
self._overrides[override_type] = override | ||
|
||
def _regenerate_ordered_callbacks(self): | ||
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added.""" | ||
self._ordered_callbacks = {} | ||
|
@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext | |
for cb in callbacks: | ||
cb.function(ctx) | ||
|
||
def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to have typed function signatures for each override type given that the signatures are known and there aren't very many of them (instead of passing *args and **kwargs). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the reason for passing orig_function? If the orig_function needs to be called, it feels like those use cases could be solved with callbacks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At least to allow extension manager to run original non-overriden implementation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Including it in the function signature implies that the function / override should handle it. For the case that you're describing, I imagined that it would just look like this: from ... import unet_forward
class AnExtension(ExtensionBase):
@override(ExtensionOverrideType.UNET_FORWARD)
def custom_unet_forward(self, ...):
# Do some stuff...
unet_forward(...)
What do you think? |
||
if self._is_canceled and self._is_canceled(): | ||
raise CanceledException | ||
|
||
override = self._overrides.get(override_type, None) | ||
if override is not None: | ||
return override.function(orig_function, *args, **kwargs) | ||
else: | ||
return orig_function(*args, **kwargs) | ||
|
||
@contextmanager | ||
def patch_extensions(self, ctx: DenoiseContext): | ||
if self._is_canceled and self._is_canceled(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would we want this? Seems like it just opens the door for a bunch of messiness.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How else tiled denoise will be able to call original step function or callbacks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this will be easier to discuss in the context of the tiled denoise PR? It seems to me that if we can avoid passing the ext_manager down to callbacks/overrides then that would keep things quite a bit simpler.