diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 2787074265c..6d8cde8bfa0 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -39,7 +39,7 @@ from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_patcher import ModelPatcher -from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless +from invokeai.backend.stable_diffusion import PipelineIntermediateState from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs from invokeai.backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, @@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt +from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -833,6 +834,10 @@ def step_callback(state: PipelineIntermediateState) -> None: if self.unet.freeu_config: ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) + ### seamless + if self.unet.seamless_axes: + ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes)) + # context for loading additional models with ExitStack() as exit_stack: # later should be smth like: @@ -915,7 +920,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ExitStack() as exit_stack, unet_info.model_on_device() as (model_state_dict, unet), ModelPatcher.apply_freeu(unet, self.unet.freeu_config), - set_seamless(unet, self.unet.seamless_axes), # FIXME + SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet( unet, diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index cc8a9c44a3f..35b8483f2cc 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -24,7 +24,7 @@ from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.stable_diffusion import set_seamless +from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice @@ -59,7 +59,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae_info = context.models.load(self.vae.vae) assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) - with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: + with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae: assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) latents = latents.to(vae.device) if self.fp32: diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index 440cb4410ba..6a6f2ebc49c 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -7,11 +7,9 @@ StableDiffusionGeneratorPipeline, ) from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401 -from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401 __all__ = [ "PipelineIntermediateState", "StableDiffusionGeneratorPipeline", "InvokeAIDiffuserComponent", - "set_seamless", ] diff --git a/invokeai/backend/stable_diffusion/extensions/seamless.py b/invokeai/backend/stable_diffusion/extensions/seamless.py new file mode 100644 index 00000000000..a96ea6e4d2e --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/seamless.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from diffusers.models.lora import LoRACompatibleConv + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase + + +class SeamlessExt(ExtensionBase): + def __init__( + self, + seamless_axes: List[str], + ): + super().__init__() + self._seamless_axes = seamless_axes + + @contextmanager + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + with self.static_patch_model( + model=unet, + seamless_axes=self._seamless_axes, + ): + yield + + @staticmethod + @contextmanager + def static_patch_model( + model: torch.nn.Module, + seamless_axes: List[str], + ): + if not seamless_axes: + yield + return + + x_mode = "circular" if "x" in seamless_axes else "constant" + y_mode = "circular" if "y" in seamless_axes else "constant" + + # override conv_forward + # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019 + def _conv_forward_asymmetric( + self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None + ): + self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0) + self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3]) + working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode) + working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode) + return torch.nn.functional.conv2d( + working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups + ) + + original_layers: List[Tuple[nn.Conv2d, Callable]] = [] + try: + for layer in model.modules(): + if not isinstance(layer, torch.nn.Conv2d): + continue + + if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None: + layer.lora_layer = lambda *x: 0 + original_layers.append((layer, layer._conv_forward)) + layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d) + + yield + + finally: + for layer, orig_conv_forward in original_layers: + layer._conv_forward = orig_conv_forward diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py deleted file mode 100644 index 23ed978c6d0..00000000000 --- a/invokeai/backend/stable_diffusion/seamless.py +++ /dev/null @@ -1,51 +0,0 @@ -from contextlib import contextmanager -from typing import Callable, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from diffusers.models.lora import LoRACompatibleConv -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel - - -@contextmanager -def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]): - if not seamless_axes: - yield - return - - # override conv_forward - # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019 - def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0) - self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3]) - working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode) - working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode) - return torch.nn.functional.conv2d( - working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups - ) - - original_layers: List[Tuple[nn.Conv2d, Callable]] = [] - - try: - x_mode = "circular" if "x" in seamless_axes else "constant" - y_mode = "circular" if "y" in seamless_axes else "constant" - - conv_layers: List[torch.nn.Conv2d] = [] - - for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - conv_layers.append(module) - - for layer in conv_layers: - if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None: - layer.lora_layer = lambda *x: 0 - original_layers.append((layer, layer._conv_forward)) - layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d) - - yield - - finally: - for layer, orig_conv_forward in original_layers: - layer._conv_forward = orig_conv_forward