Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular backend - Seamless #6651

Merged
merged 4 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
Expand All @@ -62,6 +62,7 @@
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
Expand Down Expand Up @@ -833,6 +834,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

### seamless
if self.unet.seamless_axes:
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))

# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
Expand Down Expand Up @@ -915,7 +920,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
Expand Down
4 changes: 2 additions & 2 deletions invokeai/app/invocations/latents_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice

Expand Down Expand Up @@ -59,7 +59,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
Expand Down
2 changes: 0 additions & 2 deletions invokeai/backend/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401

__all__ = [
"PipelineIntermediateState",
"StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent",
"set_seamless",
]
71 changes: 71 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/seamless.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class SeamlessExt(ExtensionBase):
def __init__(
self,
seamless_axes: List[str],
):
super().__init__()
self._seamless_axes = seamless_axes

@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
with self.static_patch_model(
model=unet,
seamless_axes=self._seamless_axes,
):
yield

@staticmethod
@contextmanager
def static_patch_model(
model: torch.nn.Module,
seamless_axes: List[str],
):
if not seamless_axes:
yield
return

x_mode = "circular" if "x" in seamless_axes else "constant"
y_mode = "circular" if "y" in seamless_axes else "constant"

# override conv_forward
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
def _conv_forward_asymmetric(
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
):
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
return torch.nn.functional.conv2d(
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
)

original_layers: List[Tuple[nn.Conv2d, Callable]] = []
try:
for layer in model.modules():
if not isinstance(layer, torch.nn.Conv2d):
continue

if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
layer.lora_layer = lambda *x: 0
original_layers.append((layer, layer._conv_forward))
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)

yield

finally:
for layer, orig_conv_forward in original_layers:
layer._conv_forward = orig_conv_forward
51 changes: 0 additions & 51 deletions invokeai/backend/stable_diffusion/seamless.py

This file was deleted.

Loading