-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modular backend - overrides #6692
base: main
Are you sure you want to change the base?
Modular backend - overrides #6692
Conversation
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) | ||
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) | ||
|
||
# pass extensions manager as arg to allow override access it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would we want this? Seems like it just opens the door for a bunch of messiness.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How else tiled denoise will be able to call original step function or callbacks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this will be easier to discuss in the context of the tiled denoise PR? It seems to me that if we can avoid passing the ext_manager down to callbacks/overrides then that would keep things quite a bit simpler.
@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext | |||
for cb in callbacks: | |||
cb.function(ctx) | |||
|
|||
def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to have typed function signatures for each override type given that the signatures are known and there aren't very many of them (instead of passing *args and **kwargs).
@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext | |||
for cb in callbacks: | |||
cb.function(ctx) | |||
|
|||
def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason for passing orig_function? If the orig_function needs to be called, it feels like those use cases could be solved with callbacks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At least to allow extension manager to run original non-overriden implementation.
And also simply because it more flexible, you don't need to implement underlying logic if you only patch it slightly.
Also tiled decode will use orig function of step on each tile.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Including it in the function signature implies that the function / override should handle it.
For the case that you're describing, I imagined that it would just look like this:
from ... import unet_forward
class AnExtension(ExtensionBase):
@override(ExtensionOverrideType.UNET_FORWARD)
def custom_unet_forward(self, ...):
# Do some stuff...
unet_forward(...)
What do you think?
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} | ||
self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add docs explaining the difference between _overrides and _callbacks. Include guidance for developers for how they should decide between using a callback vs. an override. In some cases, both could work, so highlight the things that should be considered to decide between them.
@@ -21,11 +26,19 @@ def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): | |||
# A list of extensions in the order that they were added to the ExtensionsManager. | |||
self._extensions: List[ExtensionBase] = [] | |||
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} | |||
self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add unit tests for the new ExtensionBase/ExtensionsManager functionality given it's core role. It should be straightforward - you can use the existing tests for reference. I think we'd roughly want tests for each of the following:
- When an override is registered, it get's called
- Calling an override type with no override registered behaves as expected
- When duplicate overrides are registered, a meaningful error is raised
Summary
Initial implementation of overrides in modular backend, should be used in inpaint and tiled extensions(also in preview extension after preview event rewrite).
Created PR now just to have ability to discuss.
To be precise - need to decide how better to implement arguments in overrides.
Related Issues / Discussions
#6606
https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d
QA Instructions
Run with set
USE_MODULAR_DENOISE
environment.Merge Plan
Discuss, then merge.
Checklist