From 39e10d894c44b73caea8c0e36776ff7b1f34b620 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 19 Jul 2024 23:17:01 +0300 Subject: [PATCH] Add invocation cancellation logic to patchers --- invokeai/backend/stable_diffusion/extensions_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index f42a065e829..1cae2e42190 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -44,8 +44,6 @@ def _regenerate_ordered_callbacks(self): self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): - # TODO: add to patchers too? - # and if so, should it be only in beginning of function or in for loop if self._is_canceled and self._is_canceled(): raise CanceledException @@ -55,6 +53,9 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext @contextmanager def patch_extensions(self, context: DenoiseContext): + if self._is_canceled and self._is_canceled(): + raise CanceledException + with ExitStack() as exit_stack: for ext in self._extensions: exit_stack.enter_context(ext.patch_extension(context)) @@ -63,5 +64,8 @@ def patch_extensions(self, context: DenoiseContext): @contextmanager def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + if self._is_canceled and self._is_canceled(): + raise CanceledException + # TODO: create logic in PR with extension which uses it yield None