Base of modular backend #6606

Merged · 27 commits merged on Jul 19, 2024
Changes from 8 commits
Commits (27)
9cc852c
Base code from draft PR
StAlKeR7779 Jul 12, 2024
0bc6037
A bit rework conditioning convert to unet kwargs
StAlKeR7779 Jul 12, 2024
87e96e1
Rename modifiers to callbacks, convert order to int, a bit unify inje…
StAlKeR7779 Jul 12, 2024
bd8ae5d
Simplify guidance modes
StAlKeR7779 Jul 12, 2024
3a9dda9
Renames
StAlKeR7779 Jul 12, 2024
7e00526
Remove overrides logic for now
StAlKeR7779 Jul 12, 2024
e961dd1
Remove remains of priority logic
StAlKeR7779 Jul 12, 2024
499e4d4
Add preview extension to check logic
StAlKeR7779 Jul 12, 2024
d623bd4
Fix condtionings logic
StAlKeR7779 Jul 15, 2024
fd8d1c1
Remove 'del' operator overload
StAlKeR7779 Jul 15, 2024
9f088d1
Multiple small fixes
StAlKeR7779 Jul 15, 2024
608cbe3
Separate inputs in denoise context
StAlKeR7779 Jul 16, 2024
cec345c
Change attention processor apply logic
StAlKeR7779 Jul 16, 2024
b7c6c63
Added some comments
StAlKeR7779 Jul 16, 2024
cd1bc15
Rename sequential as private variable
StAlKeR7779 Jul 17, 2024
ae6d4fb
Move out _concat_conditionings_for_batch submethods
StAlKeR7779 Jul 17, 2024
03e22c2
Convert conditioning_mode to enum
StAlKeR7779 Jul 17, 2024
137202b
Remove patch_unet logic for now
StAlKeR7779 Jul 17, 2024
79e35bd
Minor fixes
StAlKeR7779 Jul 17, 2024
2c2ec8f
Comments, a bit refactor
StAlKeR7779 Jul 17, 2024
3f79467
Ruff format
StAlKeR7779 Jul 17, 2024
2ef3b49
Add run cancelling logic to extension manager
StAlKeR7779 Jul 17, 2024
710dc6b
Merge branch 'main' into stalker7779/backend_base
StAlKeR7779 Jul 17, 2024
0c56d4a
Ryan's suggested changes to extension manager/extensions
StAlKeR7779 Jul 18, 2024
83a86ab
Add unit tests for ExtensionsManager and ExtensionBase.
RyanJDick Jul 19, 2024
39e10d8
Add invocation cancellation logic to patchers
StAlKeR7779 Jul 19, 2024
78d2b1b
Merge branch 'main' into stalker-backend_base
RyanJDick Jul 19, 2024
118 changes: 111 additions & 7 deletions invokeai/app/invocations/denoise_latents.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -39,6 +40,7 @@
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@@ -53,6 +55,10 @@
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +320,10 @@
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@@ -330,25 +337,25 @@
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)

cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)

if isinstance(cfg_scale, list):
@@ -707,9 +714,105 @@

return seed, noise, latents

def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
with ExitStack() as exit_stack:

Check failure on line 726 (GitHub Actions / python-checks, Ruff F841): invokeai/app/invocations/denoise_latents.py:726:29: Local variable `exit_stack` is assigned to but never used
ext_manager = ExtensionsManager()

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()

seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

_, _, latent_height, latent_width = latents.shape

conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)

timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)

denoise_ctx = DenoiseContext(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
unet=None,
scheduler=scheduler,
)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

# ext: t2i/ip adapter
ext_manager.callbacks.setup(denoise_ctx, ext_manager)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
# ext: controlnet
ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu") # TODO: detach?
TorchDevice.empty_cache()

name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +891,8 @@
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
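
For context on the new dispatch in this file: `os.environ.get("USE_MODULAR_DENOISE", False)` only checks truthiness, and environment variable values are strings, so any non-empty value (even "0") routes to `_new_invoke()`. A minimal opt-in sketch, assuming nothing beyond the dispatch shown above:

import os

# Hypothetical opt-in for local testing: any non-empty string is truthy for
# os.environ.get("USE_MODULAR_DENOISE", False), so this selects the new
# modular backend path in DenoiseLatentsInvocation.invoke().
os.environ["USE_MODULAR_DENOISE"] = "1"
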
60 changes: 60 additions & 0 deletions invokeai/backend/stable_diffusion/denoise_context.py
@@ -0,0 +1,60 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor

class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True


@dataclass
class DenoiseContext:
latents: torch.Tensor
scheduler_step_kwargs: dict[str, Any]
conditioning_data: TextConditioningData
noise: Optional[torch.Tensor]
seed: int
timesteps: torch.Tensor
init_timestep: torch.Tensor

scheduler: SchedulerMixin
unet: Optional[UNet2DConditionModel] = None

orig_latents: Optional[torch.Tensor] = None
step_index: Optional[int] = None
timestep: Optional[torch.Tensor] = None
unet_kwargs: Optional[UNetKwargs] = None
step_output: Optional[SchedulerOutput] = None

latent_model_input: Optional[torch.Tensor] = None
conditioning_mode: Optional[str] = None
negative_noise_pred: Optional[torch.Tensor] = None
positive_noise_pred: Optional[torch.Tensor] = None
noise_pred: Optional[torch.Tensor] = None

extra: dict = field(default_factory=dict)

def __delattr__(self, name: str):
setattr(self, name, None)
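
A small usage sketch of the new context object, assuming only the dataclass above; the `conditioning_data` and `scheduler` values are passed as `None` placeholders purely for illustration (real callers pass a TextConditioningData and a diffusers scheduler). Note that the `__delattr__` override means `del` resets a field to `None` rather than removing it:

import torch

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext

# Hypothetical minimal construction, for illustration only.
ctx = DenoiseContext(
    latents=torch.zeros(1, 4, 64, 64),
    scheduler_step_kwargs={},
    conditioning_data=None,   # placeholder for a TextConditioningData
    noise=None,
    seed=0,
    timesteps=torch.tensor([999, 500, 0]),
    init_timestep=torch.tensor([999]),
    scheduler=None,           # placeholder for a SchedulerMixin
)

ctx.noise_pred = torch.zeros(1, 4, 64, 64)
del ctx.noise_pred   # __delattr__ sets the field back to None instead of deleting it
assert ctx.noise_pred is None
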
18 changes: 9 additions & 9 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -23,19 +23,19 @@
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
# @dataclass
# class PipelineIntermediateState:
# step: int
# order: int
# total_steps: int
# timestep: int
# latents: torch.Tensor
# predicted_original: Optional[torch.Tensor] = None


@dataclass
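
The `PipelineIntermediateState` dataclass itself is unchanged here; its old definition is commented out rather than deleted, and consumers such as the preview callback now use the new location, as in the import added at the top of this file:

# New import location introduced by this PR.
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState
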
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
Expand Up @@ -5,6 +5,7 @@
import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


@dataclass
@@ -103,7 +104,7 @@ def __init__(
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
):
self.uncond_text = uncond_text
self.cond_text = cond_text
@@ -114,10 +115,96 @@ def __init__(
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# TODO: old backend, remove
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier

def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)

def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype

# TODO: combine regions with conditionings
if conditioning_mode == "both":
conditionings = [self.uncond_text.embeds, self.cond_text.embeds]
c_regions = [self.uncond_regions, self.cond_regions]
elif conditioning_mode == "positive":
conditionings = [self.cond_text.embeds]
c_regions = [self.cond_regions]
else:
conditionings = [self.uncond_text.embeds]
c_regions = [self.uncond_regions]

encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings)

unet_kwargs.encoder_hidden_states = encoder_hidden_states
unet_kwargs.encoder_attention_mask = encoder_attention_mask

if self.is_sdxl():
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
)

unet_kwargs.added_cond_kwargs = added_cond_kwargs

if any(r is not None for r in c_regions):
tmp_regions = []
for c, r in zip(conditionings, c_regions, strict=True):
if r is None:
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
tmp_regions.append(r)

if unet_kwargs.cross_attention_kwargs is None:
unet_kwargs.cross_attention_kwargs = {}

unet_kwargs.cross_attention_kwargs.update(
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
)

def _concat_conditionings_for_batch(self, conditionings):
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
)

if cond.shape[1] < max_len:
conditioning_attention_mask = _pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], max_len - cond.shape[1]),
dim=1,
)

cond = _pad_zeros(
cond,
pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
dim=1,
)

if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])

return cond, encoder_attention_mask

encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
for i in range(len(conditionings)):
conditionings[i], encoder_attention_mask = _pad_conditioning(
conditionings[i], max_len, encoder_attention_mask
)

return torch.cat(conditionings), encoder_attention_mask
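
As a worked example of the padding behaviour, with illustrative shapes that are not taken from the PR (a 77-token and a 154-token embedding at hidden size 768, e.g. a short and a long SD1.x prompt):

import torch

# Illustrative inputs only.
cond_a = torch.randn(1, 77, 768)    # short prompt embedding
cond_b = torch.randn(1, 154, 768)   # long prompt embedding (max_len = 154)

# _concat_conditionings_for_batch zero-pads cond_a along dim=1 up to 154 and
# builds an attention mask of 77 ones followed by 77 zeros for it, while
# cond_b keeps an all-ones mask. The method then returns:
#   encoder_hidden_states with shape (2, 154, 768)
#   encoder_attention_mask with shape (2, 154)
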