From 46d83a3026f1811d6ac60deb58efb38712699475 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 13 Feb 2024 14:24:46 -0500 Subject: [PATCH 01/21] Add a MaskField primitive, and add a mask to the ConditioningField primitive type. --- invokeai/app/invocations/conditioning.py | 41 ++++++++++++++++++++++++ invokeai/app/invocations/fields.py | 12 ++++++- 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 invokeai/app/invocations/conditioning.py diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py new file mode 100644 index 00000000000..9579d80009f --- /dev/null +++ b/invokeai/app/invocations/conditioning.py @@ -0,0 +1,41 @@ +import numpy as np +import torch +from PIL.Image import Image + +from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField + + +@invocation( + "add_conditioning_mask", + title="Add Conditioning Mask", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class AddConditioningMaskInvocation(BaseInvocation): + """Add a mask to an existing conditioning tensor.""" + + conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") + image: ImageField = InputField( + description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. " + "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask." + ) + + @staticmethod + def convert_image_to_mask(image: Image) -> torch.Tensor: + """Convert a PIL image to a uint8 mask tensor.""" + np_image = np.array(image) + torch_image = torch.from_numpy(np_image[0, :, :]) + mask = torch_image >= 128 + return mask.to(dtype=torch.uint8) + + def invoke(self, context: InvocationContext) -> ConditioningOutput: + image = context.services.images.get_pil_image(self.image.image_name) + mask = self.convert_image_to_mask(image) + + mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask" + context.services.latents.save(mask_name, mask) + + self.conditioning.mask_name = mask_name + return ConditioningOutput(conditioning=self.conditioning) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index d90c71a32de..56b9e12a6cd 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -203,6 +203,12 @@ class DenoiseMaskField(BaseModel): gradient: bool = Field(default=False, description="Used for gradient inpainting") +class MaskField(BaseModel): + """A mask primitive field.""" + + mask_name: str = Field(description="The name of a spatial mask. dtype: bool, shape: (1, h, w).") + + class LatentsField(BaseModel): """A latents tensor primitive field""" @@ -226,7 +232,11 @@ class ConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") - # endregion + mask: Optional[MaskField] = Field( + default=None, + description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " + "included regions should be set to True.", + ) class MetadataField(RootModel[dict[str, Any]]): From 1da8423f04936694ba1a554f82772facec9be458 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 10:30:55 -0500 Subject: [PATCH 02/21] Add RectangleMaskInvocation. 
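The node builds a boolean mask tensor and fills a rectangular region with True. A minimal standalone sketch of the same construction (the dimensions and offsets below are illustrative, not defaults of the node):

    import torch

    # Full mask size and the included rectangle, in pixels (illustrative values).
    height, width = 512, 512
    y_top, x_left = 128, 64
    rectangle_height, rectangle_width = 256, 320

    # Shape (1, height, width), dtype bool, matching the MaskField convention.
    mask = torch.zeros((1, height, width), dtype=torch.bool)
    mask[:, y_top : y_top + rectangle_height, x_left : x_left + rectangle_width] = True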
--- invokeai/app/invocations/conditioning.py | 41 ------------------------ invokeai/app/invocations/mask.py | 40 +++++++++++++++++++++++ invokeai/app/invocations/primitives.py | 11 +++++++ 3 files changed, 51 insertions(+), 41 deletions(-) delete mode 100644 invokeai/app/invocations/conditioning.py create mode 100644 invokeai/app/invocations/mask.py diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py deleted file mode 100644 index 9579d80009f..00000000000 --- a/invokeai/app/invocations/conditioning.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np -import torch -from PIL.Image import Image - -from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, InvocationContext, invocation -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField - - -@invocation( - "add_conditioning_mask", - title="Add Conditioning Mask", - tags=["conditioning"], - category="conditioning", - version="1.0.0", -) -class AddConditioningMaskInvocation(BaseInvocation): - """Add a mask to an existing conditioning tensor.""" - - conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") - image: ImageField = InputField( - description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. " - "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask." - ) - - @staticmethod - def convert_image_to_mask(image: Image) -> torch.Tensor: - """Convert a PIL image to a uint8 mask tensor.""" - np_image = np.array(image) - torch_image = torch.from_numpy(np_image[0, :, :]) - mask = torch_image >= 128 - return mask.to(dtype=torch.uint8) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = self.convert_image_to_mask(image) - - mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask" - context.services.latents.save(mask_name, mask) - - self.conditioning.mask_name = mask_name - return ConditioningOutput(conditioning=self.conditioning) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py new file mode 100644 index 00000000000..e892a766c1c --- /dev/null +++ b/invokeai/app/invocations/mask.py @@ -0,0 +1,40 @@ +import torch + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + InvocationContext, + invocation, +) +from invokeai.app.invocations.fields import InputField, MaskField, WithMetadata +from invokeai.app.invocations.primitives import MaskOutput + + +@invocation( + "rectangle_mask", + title="Create Rectangle Mask", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class RectangleMaskInvocation(BaseInvocation, WithMetadata): + """Create a rectangular mask.""" + + height: int = InputField(description="The height of the entire mask.") + width: int = InputField(description="The width of the entire mask.") + y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).") + x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).") + rectangle_height: int = InputField(description="The height of the rectangular masked region.") + rectangle_width: int = InputField(description="The width of the rectangular masked region.") + + def invoke(self, context: InvocationContext) -> MaskOutput: + mask = torch.zeros((1, self.height, self.width), dtype=torch.bool) 
+ mask[ + :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width + ] = True + + mask_name = context.tensors.save(mask) + return MaskOutput( + mask=MaskField(mask_name=mask_name), + width=self.width, + height=self.height, + ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 6a8e4e4531d..25930f7d004 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -14,6 +14,7 @@ Input, InputField, LatentsField, + MaskField, OutputField, UIComponent, ) @@ -405,9 +406,19 @@ def invoke(self, context: InvocationContext) -> ColorOutput: # endregion + # region Conditioning +@invocation_output("mask_output") +class MaskOutput(BaseInvocationOutput): + """A torch mask tensor.""" + + mask: MaskField = OutputField(description="The mask.") + width: int = OutputField(description="The width of the mask in pixels.") + height: int = OutputField(description="The height of the mask in pixels.") + + @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" From bf3ee1fefa57e786cb68ec5ddf7a1801bf626c92 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 10:48:45 -0500 Subject: [PATCH 03/21] Update compel nodes to accept an optional prompt mask. --- invokeai/app/invocations/compel.py | 35 +++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index c23dd3d908e..6df3301362e 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,7 +5,15 @@ from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ConditioningField, + FieldDescriptions, + Input, + InputField, + MaskField, + OutputField, + UIComponent, +) from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list @@ -36,7 +44,7 @@ title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.1.1", + version="1.2.0", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -51,6 +59,9 @@ class CompelInvocation(BaseInvocation): description=FieldDescriptions.clip, input=Input.Connection, ) + mask: Optional[MaskField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." 
+ ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -117,8 +128,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ) conditioning_name = context.conditioning.save(conditioning_data) - - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + ) + ) class SDXLPromptInvocationBase: @@ -232,7 +247,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.1.1", + version="1.2.0", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -255,6 +270,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): target_height: int = InputField(default=1024, description="") clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") + mask: Optional[MaskField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -317,7 +335,12 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + ) + ) @invocation( From ef9e0c969b33faae466cbaebd42e87abf627c777 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 28 Feb 2024 12:15:39 -0500 Subject: [PATCH 04/21] Remove scheduler_args from ConditioningData structure. 
--- invokeai/app/invocations/latent.py | 28 ++++++++++--------- .../stable_diffusion/diffusers_pipeline.py | 7 ++++- .../diffusion/conditioning_data.py | 24 ++-------------- 3 files changed, 23 insertions(+), 36 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index bc79efdeba4..70233b8f67b 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +import inspect import math from contextlib import ExitStack from functools import singledispatchmethod @@ -368,9 +369,7 @@ def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: def get_conditioning_data( self, context: InvocationContext, - scheduler: Scheduler, unet: UNet2DConditionModel, - seed: int, ) -> ConditioningData: positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) @@ -385,14 +384,6 @@ def get_conditioning_data( guidance_rescale_multiplier=self.cfg_rescale_multiplier, ) - conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME - scheduler, - # for ddim scheduler - eta=0.0, # ddim_eta - # for ancestral and sde schedulers - # flip all bits to have noise different from initial - generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF), - ) return conditioning_data def create_pipeline( @@ -636,6 +627,7 @@ def init_scheduler( steps: int, denoising_start: float, denoising_end: float, + seed: int, ) -> Tuple[int, List[int], int]: assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): @@ -664,7 +656,15 @@ def init_scheduler( timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] num_inference_steps = len(timesteps) // scheduler.order - return num_inference_steps, timesteps, init_timestep + scheduler_step_kwargs = {} + scheduler_step_signature = inspect.signature(scheduler.step) + if "generator" in scheduler_step_signature.parameters: + # At some point, someone decided that schedulers that accept a generator should use the original seed with + # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for + # reproducibility. 
+ scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)} + + return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs def prep_inpaint_mask( self, context: InvocationContext, latents: torch.Tensor @@ -758,7 +758,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) + conditioning_data = self.get_conditioning_data(context, unet) controlnet_data = self.prep_control_data( context=context, @@ -776,12 +776,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: exit_stack=exit_stack, ) - num_inference_steps, timesteps, init_timestep = self.init_scheduler( + num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( scheduler, device=unet.device, steps=self.steps, denoising_start=self.denoising_start, denoising_end=self.denoising_end, + seed=seed, ) result_latents = pipeline.latents_from_embeddings( @@ -794,6 +795,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: masked_latents=masked_latents, gradient_mask=gradient_mask, num_inference_steps=num_inference_steps, + scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, control_data=controlnet_data, ip_adapter_data=ip_adapter_data, diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index dae55a07517..80fe1a9c404 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -295,6 +295,7 @@ def latents_from_embeddings( self, latents: torch.Tensor, num_inference_steps: int, + scheduler_step_kwargs: dict[str, Any], conditioning_data: ConditioningData, *, noise: Optional[torch.Tensor], @@ -355,6 +356,7 @@ def latents_from_embeddings( latents, timesteps, conditioning_data, + scheduler_step_kwargs=scheduler_step_kwargs, additional_guidance=additional_guidance, control_data=control_data, ip_adapter_data=ip_adapter_data, @@ -381,6 +383,7 @@ def generate_latents_from_embeddings( latents: torch.Tensor, timesteps, conditioning_data: ConditioningData, + scheduler_step_kwargs: dict[str, Any], *, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, @@ -435,6 +438,7 @@ def generate_latents_from_embeddings( conditioning_data, step_index=i, total_step_count=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, additional_guidance=additional_guidance, control_data=control_data, ip_adapter_data=ip_adapter_data, @@ -466,6 +470,7 @@ def step( conditioning_data: ConditioningData, step_index: int, total_step_count: int, + scheduler_step_kwargs: dict[str, Any], additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None, @@ -569,7 +574,7 @@ def step( ) # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) + step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs) # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again. 
for guidance in additional_guidance: diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 7a7f4624c82..597905481a6 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -1,7 +1,5 @@ -import dataclasses -import inspect -from dataclasses import dataclass, field -from typing import Any, List, Optional, Union +from dataclasses import dataclass +from typing import List, Optional, Union import torch @@ -71,23 +69,5 @@ class ConditioningData: ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) """ guidance_rescale_multiplier: float = 0 - scheduler_args: dict[str, Any] = field(default_factory=dict) ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None - - @property - def dtype(self): - return self.text_embeddings.dtype - - def add_scheduler_args_if_applicable(self, scheduler, **kwargs): - scheduler_args = dict(self.scheduler_args) - step_method = inspect.signature(scheduler.step) - for name, value in kwargs.items(): - try: - step_method.bind_partial(**{name: value}) - except TypeError: - # FIXME: don't silently discard arguments - pass # debug("%s does not accept argument named %r", scheduler, name) - else: - scheduler_args[name] = value - return dataclasses.replace(self, scheduler_args=scheduler_args) From 7fe6f034050697802d3e0ff2aec513ee1e840fcb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 28 Feb 2024 13:49:02 -0500 Subject: [PATCH 05/21] Split ip_adapter_conditioning out from ConditioningData. --- invokeai/app/invocations/latent.py | 8 +------- .../stable_diffusion/diffusers_pipeline.py | 14 ++++++++++---- .../diffusion/conditioning_data.py | 2 -- .../diffusion/shared_invokeai_diffusion.py | 18 ++++++++++++------ 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 70233b8f67b..fba661671d6 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -488,7 +488,6 @@ def prep_ip_adapter_data( self, context: InvocationContext, ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], - conditioning_data: ConditioningData, exit_stack: ExitStack, ) -> Optional[list[IPAdapterData]]: """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings @@ -505,7 +504,6 @@ def prep_ip_adapter_data( return None ip_adapter_data_list = [] - conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( context.models.load(single_ip_adapter.ip_adapter_model) @@ -528,16 +526,13 @@ def prep_ip_adapter_data( single_ipa_images, image_encoder_model ) - conditioning_data.ip_adapter_conditioning.append( - IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds) - ) - ip_adapter_data_list.append( IPAdapterData( ip_adapter_model=ip_adapter_model, weight=single_ip_adapter.weight, begin_step_percent=single_ip_adapter.begin_step_percent, end_step_percent=single_ip_adapter.end_step_percent, + ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), ) ) @@ -772,7 +767,6 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ip_adapter_data = self.prep_ip_adapter_data( context=context, ip_adapter=self.ip_adapter, - 
conditioning_data=conditioning_data, exit_stack=exit_stack, ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 80fe1a9c404..53b1ef5313d 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,7 +23,7 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import normalize_device @@ -151,10 +151,11 @@ class ControlNetData: @dataclass class IPAdapterData: - ip_adapter_model: IPAdapter = Field(default=None) - # TODO: change to polymorphic so can do different weights per step (once implemented...) + ip_adapter_model: IPAdapter + ip_adapter_conditioning: IPAdapterConditioningInfo + + # Either a single weight applied to all steps, or a list of weights for each step. weight: Union[float, List[float]] = Field(default=1.0) - # weight: float = Field(default=1.0) begin_step_percent: float = Field(default=0.0) end_step_percent: float = Field(default=1.0) @@ -549,12 +550,17 @@ def step( down_intrablock_additional_residuals = accum_adapter_state + ip_adapter_conditioning = None + if ip_adapter_data is not None: + ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] + uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( sample=latent_model_input, timestep=t, # TODO: debug how handled batched and non batched timesteps step_index=step_index, total_step_count=total_step_count, conditioning_data=conditioning_data, + ip_adapter_conditioning=ip_adapter_conditioning, down_block_additional_residuals=down_block_additional_residuals, # for ControlNet mid_block_additional_residual=mid_block_additional_residual, # for ControlNet down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 597905481a6..b00c56120dc 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -69,5 +69,3 @@ class ConditioningData: ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) """ guidance_rescale_multiplier: float = 0 - - ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f55876623cd..657351e6c61 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -12,6 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ConditioningData, ExtraConditioningInfo, + IPAdapterConditioningInfo, SDXLConditioningInfo, ) @@ -199,6 +200,7 @@ def 
do_unet_step( sample: torch.Tensor, timestep: torch.Tensor, conditioning_data: ConditioningData, + ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], step_index: int, total_step_count: int, down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet @@ -223,6 +225,7 @@ def do_unet_step( x=sample, sigma=timestep, conditioning_data=conditioning_data, + ip_adapter_conditioning=ip_adapter_conditioning, cross_attention_control_types_to_do=cross_attention_control_types_to_do, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, @@ -236,6 +239,7 @@ def do_unet_step( x=sample, sigma=timestep, conditioning_data=conditioning_data, + ip_adapter_conditioning=ip_adapter_conditioning, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, down_intrablock_additional_residuals=down_intrablock_additional_residuals, @@ -297,6 +301,7 @@ def _apply_standard_conditioning( x, sigma, conditioning_data: ConditioningData, + ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter @@ -308,14 +313,14 @@ def _apply_standard_conditioning( sigma_twice = torch.cat([sigma] * 2) cross_attention_kwargs = None - if conditioning_data.ip_adapter_conditioning is not None: + if ip_adapter_conditioning is not None: # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). cross_attention_kwargs = { "ip_adapter_image_prompt_embeds": [ torch.stack( [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds] ) - for ipa_conditioning in conditioning_data.ip_adapter_conditioning + for ipa_conditioning in ip_adapter_conditioning ] } @@ -361,6 +366,7 @@ def _apply_standard_conditioning_sequentially( x: torch.Tensor, sigma, conditioning_data: ConditioningData, + ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], cross_attention_control_types_to_do: list[CrossAttentionType], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet @@ -411,12 +417,12 @@ def _apply_standard_conditioning_sequentially( cross_attention_kwargs = None # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. - if conditioning_data.ip_adapter_conditioning is not None: + if ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). cross_attention_kwargs = { "ip_adapter_image_prompt_embeds": [ torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) - for ipa_conditioning in conditioning_data.ip_adapter_conditioning + for ipa_conditioning in ip_adapter_conditioning ] } @@ -452,12 +458,12 @@ def _apply_standard_conditioning_sequentially( cross_attention_kwargs = None # Prepare IP-Adapter cross-attention kwargs for the conditioned pass. - if conditioning_data.ip_adapter_conditioning is not None: + if ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). 
cross_attention_kwargs = { "ip_adapter_image_prompt_embeds": [ torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) - for ipa_conditioning in conditioning_data.ip_adapter_conditioning + for ipa_conditioning in ip_adapter_conditioning ] } From 8923289b894802308605ffd7098cbc2e82105bea Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 11:49:32 -0500 Subject: [PATCH 06/21] Rename ConditioningData -> TextConditioningData. --- invokeai/app/invocations/latent.py | 9 ++++++--- .../backend/stable_diffusion/diffusers_pipeline.py | 11 +++++++---- .../stable_diffusion/diffusion/conditioning_data.py | 2 +- .../diffusion/shared_invokeai_diffusion.py | 8 ++++---- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fba661671d6..be5ed915819 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -57,7 +57,10 @@ from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + IPAdapterConditioningInfo, + TextConditioningData, +) from invokeai.backend.util.silence_warnings import SilenceWarnings from ...backend.stable_diffusion.diffusers_pipeline import ( @@ -370,14 +373,14 @@ def get_conditioning_data( self, context: InvocationContext, unet: UNet2DConditionModel, - ) -> ConditioningData: + ) -> TextConditioningData: positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) - conditioning_data = ConditioningData( + conditioning_data = TextConditioningData( unconditioned_embeddings=uc, text_embeddings=c, guidance_scale=self.cfg_scale, diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 53b1ef5313d..190cc9869f0 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,7 +23,10 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + IPAdapterConditioningInfo, + TextConditioningData, +) from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import normalize_device @@ -297,7 +300,7 @@ def latents_from_embeddings( latents: torch.Tensor, num_inference_steps: int, scheduler_step_kwargs: dict[str, Any], - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, *, noise: Optional[torch.Tensor], timesteps: torch.Tensor, @@ -383,7 +386,7 @@ def generate_latents_from_embeddings( 
self, latents: torch.Tensor, timesteps, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, scheduler_step_kwargs: dict[str, Any], *, additional_guidance: List[Callable] = None, @@ -468,7 +471,7 @@ def step( self, t: torch.Tensor, latents: torch.Tensor, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, step_index: int, total_step_count: int, scheduler_step_kwargs: dict[str, Any], diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index b00c56120dc..9ea8332db17 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -55,7 +55,7 @@ class IPAdapterConditioningInfo: @dataclass -class ConditioningData: +class TextConditioningData: unconditioned_embeddings: BasicConditioningInfo text_embeddings: BasicConditioningInfo """ diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 657351e6c61..5108521982c 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -10,10 +10,10 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - ConditioningData, ExtraConditioningInfo, IPAdapterConditioningInfo, SDXLConditioningInfo, + TextConditioningData, ) from .cross_attention_control import ( @@ -199,7 +199,7 @@ def do_unet_step( self, sample: torch.Tensor, timestep: torch.Tensor, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], step_index: int, total_step_count: int, @@ -300,7 +300,7 @@ def _apply_standard_conditioning( self, x, sigma, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet @@ -365,7 +365,7 @@ def _apply_standard_conditioning_sequentially( self, x: torch.Tensor, sigma, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], cross_attention_control_types_to_do: list[CrossAttentionType], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet From b76bb45104920783cbb48aa9cce10613ac120dd4 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 11:55:01 -0500 Subject: [PATCH 07/21] Improve documentation of conditioning_data.py. --- .../diffusion/conditioning_data.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 9ea8332db17..051b2fed1ff 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -8,6 +8,10 @@ @dataclass class ExtraConditioningInfo: + """Extra conditioning information produced by Compel. + This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel). 
+ """ + tokens_count_including_eos_bos: int cross_attention_control_args: Optional[Arguments] = None @@ -18,6 +22,8 @@ def wants_cross_attention_control(self): @dataclass class BasicConditioningInfo: + """SD 1/2 text conditioning information produced by Compel.""" + embeds: torch.Tensor extra_conditioning: Optional[ExtraConditioningInfo] @@ -33,6 +39,8 @@ class ConditioningFieldData: @dataclass class SDXLConditioningInfo(BasicConditioningInfo): + """SDXL text conditioning information produced by Compel.""" + pooled_embeds: torch.Tensor add_time_ids: torch.Tensor @@ -58,14 +66,11 @@ class IPAdapterConditioningInfo: class TextConditioningData: unconditioned_embeddings: BasicConditioningInfo text_embeddings: BasicConditioningInfo - """ - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). - Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate - images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - """ + # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). + # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate + # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. guidance_scale: Union[float, List[float]] - """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 . - ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) - """ + # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. + # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). guidance_rescale_multiplier: float = 0 From c059bc31628811be5d448de439a74587bef44544 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 12:57:33 -0500 Subject: [PATCH 08/21] Add TextConditioningRegions to the TextConditioningData data structure. 
--- invokeai/app/invocations/latent.py | 6 +- .../stable_diffusion/diffusers_pipeline.py | 2 +- .../diffusion/conditioning_data.py | 58 +++++++++++++++---- .../diffusion/shared_invokeai_diffusion.py | 54 +++++++++-------- 4 files changed, 79 insertions(+), 41 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index be5ed915819..0d894dcee48 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -381,8 +381,10 @@ def get_conditioning_data( uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = TextConditioningData( - unconditioned_embeddings=uc, - text_embeddings=c, + uncond_text=uc, + cond_text=c, + uncond_regions=None, + cond_regions=None, guidance_scale=self.cfg_scale, guidance_rescale_multiplier=self.cfg_rescale_multiplier, ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 190cc9869f0..7ef93b0bcbf 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -405,7 +405,7 @@ def generate_latents_from_embeddings( return latents ip_adapter_unet_patcher = None - extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning + extra_conditioning_info = conditioning_data.cond_text.extra_conditioning if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: attn_ctx = self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 051b2fed1ff..6ef6d68fca7 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -63,14 +63,52 @@ class IPAdapterConditioningInfo: @dataclass +class Range: + start: int + end: int + + +class TextConditioningRegions: + def __init__( + self, + masks: torch.Tensor, + ranges: list[Range], + ): + # A binary mask indicating the regions of the image that the prompt should be applied to. + # Shape: (1, num_prompts, height, width) + # Dtype: torch.bool + self.masks = masks + + # A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to. + # ranges[i] contains the embedding range for the i'th prompt / mask. + self.ranges = ranges + + assert self.masks.shape[1] == len(self.ranges) + + class TextConditioningData: - unconditioned_embeddings: BasicConditioningInfo - text_embeddings: BasicConditioningInfo - # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). - # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate - # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - guidance_scale: Union[float, List[float]] - # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. - # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
- guidance_rescale_multiplier: float = 0 + def __init__( + self, + uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo], + cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo], + uncond_regions: Optional[TextConditioningRegions], + cond_regions: Optional[TextConditioningRegions], + guidance_scale: Union[float, List[float]], + guidance_rescale_multiplier: float = 0, + ): + self.uncond_text = uncond_text + self.cond_text = cond_text + self.uncond_regions = uncond_regions + self.cond_regions = cond_regions + # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). + # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate + # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + self.guidance_scale = guidance_scale + # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. + # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + self.guidance_rescale_multiplier = guidance_rescale_multiplier + + def is_sdxl(self): + assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) + return isinstance(self.cond_text, SDXLConditioningInfo) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 5108521982c..46150d26218 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -12,7 +12,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ExtraConditioningInfo, IPAdapterConditioningInfo, - SDXLConditioningInfo, TextConditioningData, ) @@ -91,7 +90,7 @@ def do_controlnet_step( timestep: torch.Tensor, step_index: int, total_step_count: int, - conditioning_data, + conditioning_data: TextConditioningData, ): down_block_res_samples, mid_block_res_sample = None, None @@ -124,28 +123,28 @@ def do_controlnet_step( added_cond_kwargs = None if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned - if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: + if conditioning_data.is_sdxl(): added_cond_kwargs = { - "text_embeds": conditioning_data.text_embeddings.pooled_embeds, - "time_ids": conditioning_data.text_embeddings.add_time_ids, + "text_embeds": conditioning_data.cond_text.pooled_embeds, + "time_ids": conditioning_data.cond_text.add_time_ids, } - encoder_hidden_states = conditioning_data.text_embeddings.embeds + encoder_hidden_states = conditioning_data.cond_text.embeds encoder_attention_mask = None else: - if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: + if conditioning_data.is_sdxl(): added_cond_kwargs = { "text_embeds": torch.cat( [ # TODO: how to pad? just by zeros? or even truncate? 
- conditioning_data.unconditioned_embeddings.pooled_embeds, - conditioning_data.text_embeddings.pooled_embeds, + conditioning_data.uncond_text.pooled_embeds, + conditioning_data.cond_text.pooled_embeds, ], dim=0, ), "time_ids": torch.cat( [ - conditioning_data.unconditioned_embeddings.add_time_ids, - conditioning_data.text_embeddings.add_time_ids, + conditioning_data.uncond_text.add_time_ids, + conditioning_data.cond_text.add_time_ids, ], dim=0, ), @@ -154,8 +153,8 @@ def do_controlnet_step( encoder_hidden_states, encoder_attention_mask, ) = self._concat_conditionings_for_batch( - conditioning_data.unconditioned_embeddings.embeds, - conditioning_data.text_embeddings.embeds, + conditioning_data.uncond_text.embeds, + conditioning_data.cond_text.embeds, ) if isinstance(control_datum.weight, list): # if controlnet has multiple weights, use the weight for the current step @@ -325,27 +324,27 @@ def _apply_standard_conditioning( } added_cond_kwargs = None - if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: + if conditioning_data.is_sdxl(): added_cond_kwargs = { "text_embeds": torch.cat( [ # TODO: how to pad? just by zeros? or even truncate? - conditioning_data.unconditioned_embeddings.pooled_embeds, - conditioning_data.text_embeddings.pooled_embeds, + conditioning_data.uncond_text.pooled_embeds, + conditioning_data.cond_text.pooled_embeds, ], dim=0, ), "time_ids": torch.cat( [ - conditioning_data.unconditioned_embeddings.add_time_ids, - conditioning_data.text_embeddings.add_time_ids, + conditioning_data.uncond_text.add_time_ids, + conditioning_data.cond_text.add_time_ids, ], dim=0, ), } both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( - conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds + conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds ) both_results = self.model_forward_callback( x_twice, @@ -432,18 +431,17 @@ def _apply_standard_conditioning_sequentially( # Prepare SDXL conditioning kwargs for the unconditioned pass. added_cond_kwargs = None - is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo - if is_sdxl: + if conditioning_data.is_sdxl(): added_cond_kwargs = { - "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, - "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, + "text_embeds": conditioning_data.uncond_text.pooled_embeds, + "time_ids": conditioning_data.uncond_text.add_time_ids, } # Run unconditioned UNet denoising (i.e. negative prompt). unconditioned_next_x = self.model_forward_callback( x, sigma, - conditioning_data.unconditioned_embeddings.embeds, + conditioning_data.uncond_text.embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, @@ -474,17 +472,17 @@ def _apply_standard_conditioning_sequentially( # Prepare SDXL conditioning kwargs for the conditioned pass. added_cond_kwargs = None - if is_sdxl: + if conditioning_data.is_sdxl(): added_cond_kwargs = { - "text_embeds": conditioning_data.text_embeddings.pooled_embeds, - "time_ids": conditioning_data.text_embeddings.add_time_ids, + "text_embeds": conditioning_data.cond_text.pooled_embeds, + "time_ids": conditioning_data.cond_text.add_time_ids, } # Run conditioned UNet denoising (i.e. positive prompt). 
conditioned_next_x = self.model_forward_callback( x, sigma, - conditioning_data.text_embeddings.embeds, + conditioning_data.cond_text.embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, From 93056e4ab7b3adfa7c632b4895c76499a78d67c8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 13:42:35 -0500 Subject: [PATCH 09/21] Add support for lists of prompt embeddings to be passed to the DenoiseLatents invocation, and add handling of the conditioning region masks in DenoiseLatents. --- invokeai/app/invocations/latent.py | 185 +++++++++++++++++++++++++++-- 1 file changed, 172 insertions(+), 13 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 0d894dcee48..f2e1822c305 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,4 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - import inspect import math from contextlib import ExitStack @@ -10,6 +9,7 @@ import numpy as np import numpy.typing as npt import torch +import torchvision import torchvision.transforms as T from diffusers import AutoencoderKL, AutoencoderTiny from diffusers.configuration_utils import ConfigMixin @@ -58,8 +58,12 @@ from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, IPAdapterConditioningInfo, + Range, + SDXLConditioningInfo, TextConditioningData, + TextConditioningRegions, ) from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -288,10 +292,10 @@ def get_scheduler( class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" - positive_conditioning: ConditioningField = InputField( + positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 ) - negative_conditioning: ConditioningField = InputField( + negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1 ) noise: Optional[LatentsField] = InputField( @@ -369,26 +373,177 @@ def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: raise ValueError("cfg_scale must be greater than 1") return v + def _get_text_embeddings_and_masks( + self, + cond_list: list[ConditioningField], + context: InvocationContext, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]: + """Get the text embeddings and masks from the input conditioning fields.""" + text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] + text_embeddings_masks: list[Optional[torch.Tensor]] = [] + for cond in cond_list: + cond_data = context.conditioning.load(cond.conditioning_name) + text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) + + mask = cond.mask + if mask is not None: + mask = context.tensors.load(mask.mask_name) + text_embeddings_masks.append(mask) + + return text_embeddings, text_embeddings_masks + + def _preprocess_regional_prompt_mask( + self, mask: Optional[torch.Tensor], target_height: int, target_width: int + ) -> torch.Tensor: + """Preprocess a regional 
prompt mask to match the target height and width. + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + + Returns: + torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width). + """ + if mask is None: + return torch.ones((1, 1, target_height, target_width), dtype=torch.bool) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + return resized_mask + + def _concat_regional_text_embeddings( + self, + text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], + masks: Optional[list[Optional[torch.Tensor]]], + latent_height: int, + latent_width: int, + ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: + """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" + if masks is None: + masks = [None] * len(text_conditionings) + assert len(text_conditionings) == len(masks) + + is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo + + all_masks_are_none = all(mask is None for mask in masks) + + text_embedding = [] + pooled_embedding = None + add_time_ids = None + cur_text_embedding_len = 0 + processed_masks = [] + embedding_ranges = [] + extra_conditioning = None + + for prompt_idx, text_embedding_info in enumerate(text_conditionings): + mask = masks[prompt_idx] + if ( + text_embedding_info.extra_conditioning is not None + and text_embedding_info.extra_conditioning.wants_cross_attention_control + ): + extra_conditioning = text_embedding_info.extra_conditioning + + if is_sdxl: + # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for + # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain + # global prompt information. In an ideal case, there should be exactly one global prompt without a + # mask, but we don't enforce this. + + # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a + # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use + # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single + # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a + # pretty major breaking change to a popular node, so for now we use this hack. 
+ if pooled_embedding is None or mask is None: + pooled_embedding = text_embedding_info.pooled_embeds + if add_time_ids is None or mask is None: + add_time_ids = text_embedding_info.add_time_ids + + text_embedding.append(text_embedding_info.embeds) + if not all_masks_are_none: + embedding_ranges.append( + Range( + start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] + ) + ) + processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width)) + + cur_text_embedding_len += text_embedding_info.embeds.shape[1] + + text_embedding = torch.cat(text_embedding, dim=1) + assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len + + regions = None + if not all_masks_are_none: + regions = TextConditioningRegions( + masks=torch.cat(processed_masks, dim=1), + ranges=embedding_ranges, + ) + + if extra_conditioning is not None and len(text_conditionings) > 1: + raise ValueError( + "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple " + "prompts." + ) + + if is_sdxl: + return SDXLConditioningInfo( + embeds=text_embedding, + extra_conditioning=extra_conditioning, + pooled_embeds=pooled_embedding, + add_time_ids=add_time_ids, + ), regions + return BasicConditioningInfo( + embeds=text_embedding, + extra_conditioning=extra_conditioning, + ), regions + def get_conditioning_data( self, context: InvocationContext, unet: UNet2DConditionModel, + latent_height: int, + latent_width: int, ) -> TextConditioningData: - positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) + # Normalize self.positive_conditioning and self.negative_conditioning to lists. 
+ cond_list = self.positive_conditioning + if not isinstance(cond_list, list): + cond_list = [cond_list] + uncond_list = self.negative_conditioning + if not isinstance(uncond_list, list): + uncond_list = [uncond_list] + + cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( + cond_list, context, unet.device, unet.dtype + ) + uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( + uncond_list, context, unet.device, unet.dtype + ) - negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) + cond_text_embedding, cond_regions = self._concat_regional_text_embeddings( + text_conditionings=cond_text_embeddings, + masks=cond_text_embedding_masks, + latent_height=latent_height, + latent_width=latent_width, + ) + uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( + text_conditionings=uncond_text_embeddings, + masks=uncond_text_embedding_masks, + latent_height=latent_height, + latent_width=latent_width, + ) conditioning_data = TextConditioningData( - uncond_text=uc, - cond_text=c, - uncond_regions=None, - cond_regions=None, + uncond_text=uncond_text_embedding, + cond_text=cond_text_embedding, + uncond_regions=uncond_regions, + cond_regions=cond_regions, guidance_scale=self.cfg_scale, guidance_rescale_multiplier=self.cfg_rescale_multiplier, ) - return conditioning_data def create_pipeline( @@ -758,7 +913,11 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, unet) + + _, _, latent_height, latent_width = latents.shape + conditioning_data = self.get_conditioning_data( + context=context, unet=unet, latent_height=latent_height, latent_width=latent_width + ) controlnet_data = self.prep_control_data( context=context, From dc90ff2a45656b8d2f43bfa13919bb3e1bb11798 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 13:53:17 -0500 Subject: [PATCH 10/21] Add RegionalPromptData class for managing prompt region masks. --- .../diffusion/regional_prompt_data.py | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py new file mode 100644 index 00000000000..95f81b1f93e --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -0,0 +1,103 @@ +import torch +import torch.nn.functional as F + +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + TextConditioningRegions, +) + + +class RegionalPromptData: + """A class to manage the prompt data for regional conditioning.""" + + def __init__( + self, + regions: list[TextConditioningRegions], + device: torch.device, + dtype: torch.dtype, + max_downscale_factor: int = 8, + ): + """Initialize a `RegionalPromptData` object. + Args: + regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the + batch. + device (torch.device): The device to use for the attention masks. + dtype (torch.dtype): The data type to use for the attention masks. + max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor + in steps of 2x. 
+ """ + self._regions = regions + self._device = device + self._dtype = dtype + # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query + # sequence length of s. + self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks( + regions, max_downscale_factor + ) + self._negative_cross_attn_mask_score = -10000.0 + + def _prepare_spatial_masks( + self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 + ) -> list[dict[int, torch.Tensor]]: + """Prepare the spatial masks for all downscaling factors.""" + # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length + # of s. + batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] + + for batch_sample_regions in regions: + batch_sample_masks_by_seq_len.append({}) + + # Convert the bool masks to float masks so that max pooling can be applied. + batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype) + + # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. + downscale_factor = 1 + while downscale_factor <= max_downscale_factor: + b, _num_prompts, h, w = batch_sample_masks.shape + assert b == 1 + query_seq_len = h * w + + batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks + + downscale_factor *= 2 + if downscale_factor <= max_downscale_factor: + # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt + # regions to be lost entirely. + # TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g. + # nearest interpolation), and could potentially use a weighted mask rather than a binary mask. + batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2) + + return batch_sample_masks_by_seq_len + + def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor: + """Get the cross-attention mask for the given query sequence length. + Args: + query_seq_len: The length of the flattened spatial features at the current downscaling level. + key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention + layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure. + Returns: + torch.Tensor: The cross-attention score mask. + shape: (batch_size, query_seq_len, key_seq_len). + dtype: float + """ + batch_size = len(self._spatial_masks_by_seq_len) + batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)] + + # Create an empty attention mask with the correct shape. + attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device) + + for batch_idx in range(batch_size): + batch_sample_spatial_masks = batch_spatial_masks[batch_idx] + batch_sample_regions = self._regions[batch_idx] + + # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). 
+ _, num_prompts, _, _ = batch_sample_spatial_masks.shape + batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) + + for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): + batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone() + batch_sample_query_mask = batch_sample_query_scores > 0.5 + batch_sample_query_scores[batch_sample_query_mask] = 0.0 + batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score + attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores + + return attn_mask From b76720ffe13541bd948340151502f51fbffac7f6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 15 Feb 2024 17:52:44 -0500 Subject: [PATCH 11/21] Initialize a RegionalPromptAttnProcessor2_0 class by copying AttnProcessor2_0 from diffusers. --- .../diffusion/custom_atttention.py | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 invokeai/backend/stable_diffusion/diffusion/custom_atttention.py diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py new file mode 100644 index 00000000000..a528bb19069 --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -0,0 +1,85 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention, AttnProcessor2_0 +from diffusers.utils import USE_PEFT_BACKEND + + +class CustomAttnProcessor2_0(AttnProcessor2_0): + """An attention processor that supports regional prompt attention for PyTorch 2.0.""" + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, 
attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states From 203d4a660340451b0512ca4e898eb2e8d4cc2065 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 14:03:33 -0500 Subject: [PATCH 12/21] Update CustomAttention to support both IP-Adapters and regional prompting. --- .../diffusion/custom_atttention.py | 122 +++++++++++++++++- 1 file changed, 120 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index a528bb19069..47f81ff7aa7 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -5,9 +5,44 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0 from diffusers.utils import USE_PEFT_BACKEND +from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData + class CustomAttnProcessor2_0(AttnProcessor2_0): - """An attention processor that supports regional prompt attention for PyTorch 2.0.""" + """A custom implementation of AttnProcessor2_0 that supports additional Invoke features. + This implementation is based on + https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204 + Supported custom features: + - IP-Adapter + - Regional prompt attention + """ + + def __init__( + self, + ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None, + ip_adapter_scales: Optional[list[float]] = None, + ): + """Initialize a CustomAttnProcessor2_0. + Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are + layer-specific are passed to __init__(). + Args: + ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights + for the i'th IP-Adapter. + ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for + the i'th IP-Adapter. 
+ """ + super().__init__() + + self._ip_adapter_weights = ip_adapter_weights + self._ip_adapter_scales = ip_adapter_scales + + assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None) + if self._ip_adapter_weights is not None: + assert len(ip_adapter_weights) == len(ip_adapter_scales) + + def _is_ip_adapter_enabled(self) -> bool: + return self._ip_adapter_weights is not None def __call__( self, @@ -17,7 +52,25 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, + # For regional prompting: + regional_prompt_data: Optional[RegionalPromptData] = None, + percent_through: Optional[torch.FloatTensor] = None, + # For IP-Adapter: + ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.FloatTensor: + """Apply attention. + Args: + regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to + apply regional prompt masking. + ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch. + ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each + tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len). + """ + # If true, we are doing cross-attention, if false we are doing self-attention. + is_cross_attention = encoder_hidden_states is not None + + # Start unmodified block from AttnProcessor2_0. + # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -31,7 +84,25 @@ def __call__( batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # End unmodified block from AttnProcessor2_0. + + # Handle regional prompt attention masks. + if regional_prompt_data is not None: + assert percent_through is not None + _, query_seq_len, _ = hidden_states.shape + if is_cross_attention: + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_seq_len, key_seq_len=sequence_length + ) + + if attention_mask is None: + attention_mask = prompt_region_attention_mask + else: + attention_mask = prompt_region_attention_mask + attention_mask + + # Start unmodified block from AttnProcessor2_0. + # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be @@ -68,7 +139,54 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # End unmodified block from AttnProcessor2_0. + + # Apply IP-Adapter conditioning. + if is_cross_attention and self._is_ip_adapter_enabled(): + if self._is_ip_adapter_enabled(): + assert ip_adapter_image_prompt_embeds is not None + for ipa_embed, ipa_weights, scale in zip( + ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True + ): + # The batch dimensions should match. + assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] + # The token_len dimensions should match. 
+ assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] + + ip_hidden_states = ipa_embed + + # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) + + ip_key = ipa_weights.to_k_ip(ip_hidden_states) + ip_value = ipa_weights.to_v_ip(ip_hidden_states) + + # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) + + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + + hidden_states = hidden_states + scale * ip_hidden_states + else: + # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in. + assert ip_adapter_image_prompt_embeds is None + # Start unmodified block from AttnProcessor2_0. + # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout From 787a085efc3fc229487649b19896c0eb1aaec91a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 14:15:16 -0500 Subject: [PATCH 13/21] Create a UNetAttentionPatcher for patching UNet models with CustomAttnProcessor2_0 modules. --- .../backend/ip_adapter/attention_processor.py | 182 ------------------ .../stable_diffusion/diffusers_pipeline.py | 6 +- .../diffusion/unet_attention_patcher.py} | 38 ++-- tests/backend/ip_adapter/test_ip_adapter.py | 4 +- 4 files changed, 25 insertions(+), 205 deletions(-) delete mode 100644 invokeai/backend/ip_adapter/attention_processor.py rename invokeai/backend/{ip_adapter/unet_patcher.py => stable_diffusion/diffusion/unet_attention_patcher.py} (53%) diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py deleted file mode 100644 index 195cb12d1b8..00000000000 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ /dev/null @@ -1,182 +0,0 @@ -# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) -# and modified as needed - -# tencent-ailab comment: -# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0 - -from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights - - -# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict -# loading. 
-class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module): - def __init__(self): - DiffusersAttnProcessor2_0.__init__(self) - nn.Module.__init__(self) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the - ip_adapter_image_prompt_embeds parameter. - """ - return DiffusersAttnProcessor2_0.__call__( - self, attn, hidden_states, encoder_hidden_states, attention_mask, temb - ) - - -class IPAttnProcessor2_0(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - """ - - def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - assert len(weights) == len(scales) - - self._weights = weights - self._scales = scales - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Apply IP-Adapter attention. - - Args: - ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings. - Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len). - """ - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = 
hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case, - # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here. - assert ip_adapter_image_prompt_embeds is not None - assert len(ip_adapter_image_prompt_embeds) == len(self._weights) - - for ipa_embed, ipa_weights, scale in zip( - ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True - ): - # The batch dimensions should match. - assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. - assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] - - ip_hidden_states = ipa_embed - - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - - hidden_states = hidden_states + scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 7ef93b0bcbf..2c765c0380d 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -22,12 +22,12 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( IPAdapterConditioningInfo, TextConditioningData, ) from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import normalize_device @@ -415,7 +415,7 @@ def generate_latents_from_embeddings( elif ip_adapter_data is not None: # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? 
# As it is now, the IP-Adapter will silently be skipped. - ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) + ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) self.use_ip_adapter = True else: @@ -479,7 +479,7 @@ def step( control_data: List[ControlNetData] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - ip_adapter_unet_patcher: Optional[UNetPatcher] = None, + ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py similarity index 53% rename from invokeai/backend/ip_adapter/unet_patcher.py rename to invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index f8c1870f6ee..364ec18da42 100644 --- a/invokeai/backend/ip_adapter/unet_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -1,52 +1,54 @@ from contextlib import contextmanager +from typing import Optional from diffusers.models import UNet2DConditionModel -from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 -class UNetPatcher: - """A class that contains multiple IP-Adapters and can apply them to a UNet.""" +class UNetAttentionPatcher: + """A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" - def __init__(self, ip_adapters: list[IPAdapter]): + def __init__(self, ip_adapters: Optional[list[IPAdapter]]): self._ip_adapters = ip_adapters - self._scales = [1.0] * len(self._ip_adapters) + self._ip_adapter_scales = None + + if self._ip_adapters is not None: + self._ip_adapter_scales = [1.0] * len(self._ip_adapters) def set_scale(self, idx: int, value: float): - self._scales[idx] = value + self._ip_adapter_scales[idx] = value def _prepare_attention_processors(self, unet: UNet2DConditionModel): """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention - weights into them. - + weights into them (if IP-Adapters are being applied). Note that the `unet` param is only used to determine attention block dimensions and naming. """ # Construct a dict of attention processors based on the UNet's architecture. attn_procs = {} for idx, name in enumerate(unet.attn_processors.keys()): - if name.endswith("attn1.processor"): - attn_procs[name] = AttnProcessor2_0() + if name.endswith("attn1.processor") or self._ip_adapters is None: + # "attn1" processors do not use IP-Adapters. + attn_procs[name] = CustomAttnProcessor2_0() else: # Collect the weights from each IP Adapter for the idx'th attention processor. 
- attn_procs[name] = IPAttnProcessor2_0( + attn_procs[name] = CustomAttnProcessor2_0( [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], - self._scales, + self._ip_adapter_scales, ) return attn_procs @contextmanager def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): - """A context manager that patches `unet` with IP-Adapter attention processors.""" - + """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers.""" attn_procs = self._prepare_attention_processors(unet) - orig_attn_processors = unet.attn_processors try: - # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the - # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy - # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. + # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from + # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a + # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. unet.set_attn_processor(attn_procs) yield None finally: diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 9ed3c9bc507..138de398882 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -1,8 +1,8 @@ import pytest import torch -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.test_utils import install_and_load_model @@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device): ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device) cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]} - ip_adapter_unet_patcher = UNetPatcher([ip_adapter]) + ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter]) with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet): output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample From ee34091bdb9806895bba01fc2fd36731795c9e1f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 8 Mar 2024 14:34:49 -0500 Subject: [PATCH 14/21] Update the diffusion logic to use the new regional prompting feature. 
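The masks and token ranges introduced in the previous patches feed the cross-attention layers as an
additive score bias, which the diff below wires into CustomAttnProcessor2_0. A minimal sketch of that
idea, with illustrative names and shapes rather than the exact RegionalPromptData implementation: each
prompt's boolean spatial mask is flattened to the query sequence length, and positions outside the
region receive a large negative score over that prompt's token span, so the softmax suppresses them.

    import torch

    NEG_SCORE = -10000.0

    def regional_attn_bias(masks, ranges, key_seq_len):
        # masks: (num_prompts, query_seq_len) bool tensor; ranges[i] is prompt i's (start, end) token
        # span within the concatenated text embedding.
        query_seq_len = masks.shape[1]
        bias = torch.zeros((query_seq_len, key_seq_len))
        for prompt_idx, (start, end) in enumerate(ranges):
            scores = torch.full((query_seq_len,), NEG_SCORE)
            scores[masks[prompt_idx]] = 0.0
            bias[:, start:end] = scores.unsqueeze(-1)  # broadcast the per-pixel score over this prompt's tokens
        return bias

    # Example: a 2x2 latent (query_seq_len=4) and two prompts of 3 tokens each (key_seq_len=6).
    masks = torch.tensor([[True, True, False, False], [False, False, True, True]])
    bias = regional_attn_bias(masks, ranges=[(0, 3), (3, 6)], key_seq_len=6)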
--- .../stable_diffusion/diffusers_pipeline.py | 43 +++++---- .../diffusion/custom_atttention.py | 9 +- .../diffusion/shared_invokeai_diffusion.py | 90 +++++++++++++------ 3 files changed, 96 insertions(+), 46 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 2c765c0380d..c33f7b73705 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -404,22 +404,35 @@ def generate_latents_from_embeddings( if timesteps.shape[0] == 0: return latents - ip_adapter_unet_patcher = None extra_conditioning_info = conditioning_data.cond_text.extra_conditioning - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + use_cross_attention_control = ( + extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control + ) + use_ip_adapter = ip_adapter_data is not None + use_regional_prompting = ( + conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None + ) + if use_cross_attention_control and use_ip_adapter: + raise ValueError( + "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously." + ) + if use_cross_attention_control and use_regional_prompting: + raise ValueError( + "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously." + ) + + unet_attention_patcher = None + self.use_ip_adapter = use_ip_adapter + attn_ctx = nullcontext() + if use_cross_attention_control: attn_ctx = self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, extra_conditioning_info=extra_conditioning_info, ) - self.use_ip_adapter = False - elif ip_adapter_data is not None: - # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? - # As it is now, the IP-Adapter will silently be skipped. 
- ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) - attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) - self.use_ip_adapter = True - else: - attn_ctx = nullcontext() + if use_ip_adapter or use_regional_prompting: + ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None + unet_attention_patcher = UNetAttentionPatcher(ip_adapters) + attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) with attn_ctx: if callback is not None: @@ -447,7 +460,7 @@ def generate_latents_from_embeddings( control_data=control_data, ip_adapter_data=ip_adapter_data, t2i_adapter_data=t2i_adapter_data, - ip_adapter_unet_patcher=ip_adapter_unet_patcher, + unet_attention_patcher=unet_attention_patcher, ) latents = step_output.prev_sample predicted_original = getattr(step_output, "pred_original_sample", None) @@ -479,7 +492,7 @@ def step( control_data: List[ControlNetData] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None, + unet_attention_patcher: Optional[UNetAttentionPatcher] = None, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] @@ -506,10 +519,10 @@ def step( ) if step_index >= first_adapter_step and step_index <= last_adapter_step: # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range. - ip_adapter_unet_patcher.set_scale(i, weight) + unet_attention_patcher.set_scale(i, weight) else: # Otherwise, set the IP-Adapter's scale to 0, so it has no effect. - ip_adapter_unet_patcher.set_scale(i, 0.0) + unet_attention_patcher.set_scale(i, 0.0) # Handle ControlNet(s) down_block_additional_residuals = None diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 47f81ff7aa7..2f7523dd469 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -88,13 +88,12 @@ def __call__( # End unmodified block from AttnProcessor2_0. # Handle regional prompt attention masks. 
- if regional_prompt_data is not None: + if regional_prompt_data is not None and is_cross_attention: assert percent_through is not None _, query_seq_len, _ = hidden_states.shape - if is_cross_attention: - prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( - query_seq_len=query_seq_len, key_seq_len=sequence_length - ) + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_seq_len, key_seq_len=sequence_length + ) if attention_mask is None: attention_mask = prompt_region_attention_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 46150d26218..8ba988a0ebc 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -12,8 +12,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ExtraConditioningInfo, IPAdapterConditioningInfo, + Range, TextConditioningData, + TextConditioningRegions, ) +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from .cross_attention_control import ( CrossAttentionType, @@ -206,9 +209,9 @@ def do_unet_step( mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter ): + percent_through = step_index / total_step_count cross_attention_control_types_to_do = [] if self.cross_attention_control_context is not None: - percent_through = step_index / total_step_count cross_attention_control_types_to_do = ( self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through) ) @@ -225,6 +228,7 @@ def do_unet_step( sigma=timestep, conditioning_data=conditioning_data, ip_adapter_conditioning=ip_adapter_conditioning, + percent_through=percent_through, cross_attention_control_types_to_do=cross_attention_control_types_to_do, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, @@ -239,6 +243,7 @@ def do_unet_step( sigma=timestep, conditioning_data=conditioning_data, ip_adapter_conditioning=ip_adapter_conditioning, + percent_through=percent_through, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, down_intrablock_additional_residuals=down_intrablock_additional_residuals, @@ -301,6 +306,7 @@ def _apply_standard_conditioning( sigma, conditioning_data: TextConditioningData, ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], + percent_through: float, down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter @@ -311,17 +317,13 @@ def _apply_standard_conditioning( x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) - cross_attention_kwargs = None + cross_attention_kwargs = {} if ip_adapter_conditioning is not None: # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). 
- cross_attention_kwargs = { - "ip_adapter_image_prompt_embeds": [ - torch.stack( - [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds] - ) - for ipa_conditioning in ip_adapter_conditioning - ] - } + cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [ + torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]) + for ipa_conditioning in ip_adapter_conditioning + ] added_cond_kwargs = None if conditioning_data.is_sdxl(): @@ -343,6 +345,31 @@ def _apply_standard_conditioning( ), } + if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: + # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings + # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems + # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of + # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly + # awkward to handle both standard conditioning and sequential conditioning further up the stack. + regions = [] + for c, r in [ + (conditioning_data.uncond_text, conditioning_data.uncond_regions), + (conditioning_data.cond_text, conditioning_data.cond_regions), + ]: + if r is None: + # Create a dummy mask and range for text conditioning that doesn't have region masks. + _, _, h, w = x.shape + r = TextConditioningRegions( + masks=torch.ones((1, 1, h, w), dtype=torch.bool), + ranges=[Range(start=0, end=c.embeds.shape[1])], + ) + regions.append(r) + + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( + regions=regions, device=x.device, dtype=x.dtype + ) + cross_attention_kwargs["percent_through"] = percent_through + both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds ) @@ -366,6 +393,7 @@ def _apply_standard_conditioning_sequentially( sigma, conditioning_data: TextConditioningData, ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], + percent_through: float, cross_attention_control_types_to_do: list[CrossAttentionType], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet @@ -413,21 +441,19 @@ def _apply_standard_conditioning_sequentially( # Unconditioned pass ##################### - cross_attention_kwargs = None + cross_attention_kwargs = {} # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. if ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). - cross_attention_kwargs = { - "ip_adapter_image_prompt_embeds": [ - torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) - for ipa_conditioning in ip_adapter_conditioning - ] - } + cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [ + torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) + for ipa_conditioning in ip_adapter_conditioning + ] # Prepare cross-attention control kwargs for the unconditioned pass. 
if cross_attn_processor_context is not None: - cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} + cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context # Prepare SDXL conditioning kwargs for the unconditioned pass. added_cond_kwargs = None @@ -437,6 +463,13 @@ def _apply_standard_conditioning_sequentially( "time_ids": conditioning_data.uncond_text.add_time_ids, } + # Prepare prompt regions for the unconditioned pass. + if conditioning_data.uncond_regions is not None: + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( + regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype + ) + cross_attention_kwargs["percent_through"] = percent_through + # Run unconditioned UNet denoising (i.e. negative prompt). unconditioned_next_x = self.model_forward_callback( x, @@ -453,22 +486,20 @@ def _apply_standard_conditioning_sequentially( # Conditioned pass ################### - cross_attention_kwargs = None + cross_attention_kwargs = {} # Prepare IP-Adapter cross-attention kwargs for the conditioned pass. if ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). - cross_attention_kwargs = { - "ip_adapter_image_prompt_embeds": [ - torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) - for ipa_conditioning in ip_adapter_conditioning - ] - } + cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [ + torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) + for ipa_conditioning in ip_adapter_conditioning + ] # Prepare cross-attention control kwargs for the conditioned pass. if cross_attn_processor_context is not None: cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do - cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} + cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context # Prepare SDXL conditioning kwargs for the conditioned pass. added_cond_kwargs = None @@ -478,6 +509,13 @@ def _apply_standard_conditioning_sequentially( "time_ids": conditioning_data.cond_text.add_time_ids, } + # Prepare prompt regions for the conditioned pass. + if conditioning_data.cond_regions is not None: + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( + regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype + ) + cross_attention_kwargs["percent_through"] = percent_through + # Run conditioned UNet denoising (i.e. positive prompt). conditioned_next_x = self.model_forward_callback( x, From 4f97192c053735d9c3d7f2fbf5f2b1000a8af3ad Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 11 Mar 2024 09:45:25 -0400 Subject: [PATCH 15/21] (minor) The latest ruff version has _slightly_ different formatting preferences. 
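One subtlety in the batched conditioning path added in the previous patch: when only one of the
positive/negative prompts carries region masks, the other side still needs an entry so both can be
stacked into a single RegionalPromptData. The real code builds a TextConditioningRegions with a single
all-ones mask and one Range spanning the full embedding; a rough standalone sketch of that fallback,
with assumed names:

    import torch

    def dummy_regions(latent_height, latent_width, embed_len):
        # One mask covering the whole latent and one token range covering the whole prompt embedding,
        # so this prompt behaves exactly as if no regional masking were applied.
        masks = torch.ones((1, 1, latent_height, latent_width), dtype=torch.bool)
        ranges = [(0, embed_len)]
        return masks, ranges

    masks, ranges = dummy_regions(latent_height=64, latent_width=64, embed_len=77)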
--- invokeai/app/invocations/mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index e892a766c1c..fc5ff2bde69 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -28,9 +28,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata): def invoke(self, context: InvocationContext) -> MaskOutput: mask = torch.zeros((1, self.height, self.width), dtype=torch.bool) - mask[ - :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width - ] = True + mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = ( + True + ) mask_name = context.tensors.save(mask) return MaskOutput( From 3a531c50976beae6d3b544d649e48d934d8996ce Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 8 Apr 2024 18:11:00 +1000 Subject: [PATCH 16/21] feat(nodes): add prompt region from image nodes --- invokeai/app/invocations/mask.py | 79 +++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index fc5ff2bde69..572fd7c15df 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -1,11 +1,15 @@ +import numpy as np import torch +from pydantic import BaseModel from invokeai.app.invocations.baseinvocation import ( BaseInvocation, + BaseInvocationOutput, InvocationContext, invocation, + invocation_output, ) -from invokeai.app.invocations.fields import InputField, MaskField, WithMetadata +from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata from invokeai.app.invocations.primitives import MaskOutput @@ -38,3 +42,76 @@ def invoke(self, context: InvocationContext) -> MaskOutput: width=self.width, height=self.height, ) + + +class PromptColorPair(BaseModel): + prompt: str + color: ColorField + + +class PromptMaskPair(BaseModel): + prompt: str + mask: MaskField + + +default_prompt_color_pairs = [ + PromptColorPair(prompt="Strawberries", color=ColorField(r=200, g=0, b=0, a=255)), + PromptColorPair(prompt="Frog", color=ColorField(r=0, g=200, b=0, a=255)), + PromptColorPair(prompt="Banana", color=ColorField(r=0, g=0, b=200, a=255)), + PromptColorPair(prompt="A gnome", color=ColorField(r=215, g=0, b=255, a=255)), +] + + +@invocation_output("extract_masks_and_prompts_output") +class ExtractMasksAndPromptsOutput(BaseInvocationOutput): + prompt_mask_pairs: list[PromptMaskPair] = OutputField(description="List of prompts and their corresponding masks.") + + +@invocation( + "extract_masks_and_prompts", + title="Extract Masks and Prompts", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class ExtractMasksAndPromptsInvocation(BaseInvocation): + """Extract masks and prompts from a segmented mask image and prompt-to-color map.""" + + prompt_color_pairs: list[PromptColorPair] = InputField( + default=default_prompt_color_pairs, description="List of prompts and their corresponding colors." 
+ ) + image: ImageField = InputField(description="Mask to apply to the prompts.") + + def invoke(self, context: InvocationContext) -> ExtractMasksAndPromptsOutput: + prompt_mask_pairs: list[PromptMaskPair] = [] + image = context.images.get_pil(self.image.image_name) + image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8)) + + for pair in self.prompt_color_pairs: + mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) + mask_name = context.tensors.save(mask) + prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) + + return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) + + +@invocation_output("split_mask_prompt_pair_output") +class SplitMaskPromptPairOutput(BaseInvocationOutput): + prompt: str = OutputField() + mask: MaskField = OutputField() + + +@invocation( + "split_mask_prompt_pair", + title="Split Mask-Prompt pair", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class SplitMaskPromptPair(BaseInvocation): + """Extract masks and prompts from a segmented mask image and prompt-to-color map.""" + + prompt_mask_pair: PromptMaskPair = InputField() + + def invoke(self, context: InvocationContext) -> SplitMaskPromptPairOutput: + return SplitMaskPromptPairOutput(mask=self.prompt_mask_pair.mask, prompt=self.prompt_mask_pair.prompt) From 98900a7ff1de2dea52c1742d0d87057096e77240 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 10:55:54 -0400 Subject: [PATCH 17/21] Pull the upstream changes from diffusers' AttnProcessor2_0 into CustomAttnProcessor2_0. This fixes a bug in CustomAttnProcessor2_0 that was being triggered when peft was not installed. The bug was present in a block of code that was previously copied from diffusers. The bug seems to have been introduced during diffusers' migration to PEFT for their LoRA handling. The upstream bug was fixed in https://github.com/huggingface/diffusers/commit/531e719163d2d7cf0d725bb685c1e8fe3393b9da. 
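A minimal illustration of the failure mode described above, assuming that without the PEFT backend the
attention projections are plain torch.nn.Linear layers: the copied block then passed the LoRA scale as
an extra positional argument (args = (scale,)), which Linear.forward() does not accept.

    import torch

    to_q = torch.nn.Linear(8, 8)
    hidden_states = torch.randn(1, 4, 8)

    to_q(hidden_states)  # fine
    try:
        to_q(hidden_states, 1.0)  # mirrors attn.to_q(hidden_states, *args) with args = (scale,)
    except TypeError as err:
        print(err)  # e.g. "forward() takes 2 positional arguments but 3 were given"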
--- .../stable_diffusion/diffusion/custom_atttention.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 2f7523dd469..667fcd9a645 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention, AttnProcessor2_0 -from diffusers.utils import USE_PEFT_BACKEND from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData @@ -51,7 +50,6 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, # For regional prompting: regional_prompt_data: Optional[RegionalPromptData] = None, percent_through: Optional[torch.FloatTensor] = None, @@ -111,16 +109,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -187,7 +184,7 @@ def __call__( # Start unmodified block from AttnProcessor2_0. # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) From 826f3d625a61ecaeef59864aabb61af40ab5338a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 12:27:57 -0400 Subject: [PATCH 18/21] Fix dimensions of mask produced by ExtractMasksAndPromptsInvocation. Also, added a clearer error message in case the same error is introduced in the future. --- invokeai/app/invocations/latent.py | 5 +++++ invokeai/app/invocations/mask.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 764e744a2ec..db7cd201725 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -401,6 +401,11 @@ def _preprocess_regional_prompt_mask( tf = torchvision.transforms.Resize( (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST ) + + if len(mask.shape) != 3 or mask.shape[0] != 1: + raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).") + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). 
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) return resized_mask diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 572fd7c15df..31eb70e0567 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -88,10 +88,12 @@ def invoke(self, context: InvocationContext) -> ExtractMasksAndPromptsOutput: image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8)) for pair in self.prompt_color_pairs: + # TODO(ryand): Make this work for both RGB and RGBA images. mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) + # Add explicit channel dimension. + mask = mask.unsqueeze(0) mask_name = context.tensors.save(mask) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) - return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) From 26a2b23fa661ba1a1a2645e5b50cf2427e0cff06 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 14:16:22 -0400 Subject: [PATCH 19/21] Rename MaskField to be a generice TensorField. --- invokeai/app/invocations/compel.py | 6 +++--- invokeai/app/invocations/fields.py | 10 +++++----- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/mask.py | 14 +++++++------- invokeai/app/invocations/primitives.py | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 6df3301362e..92012691ea2 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -10,8 +10,8 @@ FieldDescriptions, Input, InputField, - MaskField, OutputField, + TensorField, UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput @@ -59,7 +59,7 @@ class CompelInvocation(BaseInvocation): description=FieldDescriptions.clip, input=Input.Connection, ) - mask: Optional[MaskField] = InputField( + mask: Optional[TensorField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) @@ -270,7 +270,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): target_height: int = InputField(default=1024, description="") clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") - mask: Optional[MaskField] = InputField( + mask: Optional[TensorField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 56b9e12a6cd..0fa0216f1c7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -203,10 +203,10 @@ class DenoiseMaskField(BaseModel): gradient: bool = Field(default=False, description="Used for gradient inpainting") -class MaskField(BaseModel): - """A mask primitive field.""" +class TensorField(BaseModel): + """A tensor primitive field.""" - mask_name: str = Field(description="The name of a spatial mask. 
dtype: bool, shape: (1, h, w).") + tensor_name: str = Field(description="The name of a tensor.") class LatentsField(BaseModel): @@ -232,9 +232,9 @@ class ConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") - mask: Optional[MaskField] = Field( + mask: Optional[TensorField] = Field( default=None, - description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " + description="The mask associated with this conditioning tensor. Excluded regions should be set to False, " "included regions should be set to True.", ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index db7cd201725..3070cd1e703 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -380,7 +380,7 @@ def _get_text_embeddings_and_masks( mask = cond.mask if mask is not None: - mask = context.tensors.load(mask.mask_name) + mask = context.tensors.load(mask.tensor_name) text_embeddings_masks.append(mask) return text_embeddings, text_embeddings_masks diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 31eb70e0567..de4887e20d1 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -9,7 +9,7 @@ invocation, invocation_output, ) -from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata from invokeai.app.invocations.primitives import MaskOutput @@ -36,9 +36,9 @@ def invoke(self, context: InvocationContext) -> MaskOutput: True ) - mask_name = context.tensors.save(mask) + mask_tensor_name = context.tensors.save(mask) return MaskOutput( - mask=MaskField(mask_name=mask_name), + mask=TensorField(tensor_name=mask_tensor_name), width=self.width, height=self.height, ) @@ -51,7 +51,7 @@ class PromptColorPair(BaseModel): class PromptMaskPair(BaseModel): prompt: str - mask: MaskField + mask: TensorField default_prompt_color_pairs = [ @@ -92,15 +92,15 @@ def invoke(self, context: InvocationContext) -> ExtractMasksAndPromptsOutput: mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) # Add explicit channel dimension. 
mask = mask.unsqueeze(0) - mask_name = context.tensors.save(mask) - prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) + mask_tensor_name = context.tensors.save(mask) + prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name))) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) @invocation_output("split_mask_prompt_pair_output") class SplitMaskPromptPairOutput(BaseInvocationOutput): prompt: str = OutputField() - mask: MaskField = OutputField() + mask: TensorField = OutputField() @invocation( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 25930f7d004..28f72fb377a 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -14,8 +14,8 @@ Input, InputField, LatentsField, - MaskField, OutputField, + TensorField, UIComponent, ) from invokeai.app.services.images.images_common import ImageDTO @@ -414,7 +414,7 @@ def invoke(self, context: InvocationContext) -> ColorOutput: class MaskOutput(BaseInvocationOutput): """A torch mask tensor.""" - mask: MaskField = OutputField(description="The mask.") + mask: TensorField = OutputField(description="The mask.") width: int = OutputField(description="The width of the mask in pixels.") height: int = OutputField(description="The height of the mask in pixels.") From eb328421db8c9b325a995aa1fce35f077c58eff0 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 15:07:49 -0400 Subject: [PATCH 20/21] Add utility to_standard_float_mask(...) to convert various mask formats to a standardized format. --- invokeai/app/invocations/latent.py | 20 +++-- invokeai/app/invocations/mask.py | 2 - .../diffusion/regional_prompt_data.py | 1 - .../diffusion/shared_invokeai_diffusion.py | 2 +- invokeai/backend/util/mask.py | 53 +++++++++++ tests/backend/util/test_mask.py | 88 +++++++++++++++++++ 6 files changed, 155 insertions(+), 11 deletions(-) create mode 100644 invokeai/backend/util/mask.py create mode 100644 tests/backend/util/test_mask.py diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3070cd1e703..d5babe42cc1 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -61,6 +61,7 @@ TextConditioningData, TextConditioningRegions, ) +from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.silence_warnings import SilenceWarnings from ...backend.stable_diffusion.diffusers_pipeline import ( @@ -386,25 +387,25 @@ def _get_text_embeddings_and_masks( return text_embeddings, text_embeddings_masks def _preprocess_regional_prompt_mask( - self, mask: Optional[torch.Tensor], target_height: int, target_width: int + self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype ) -> torch.Tensor: """Preprocess a regional prompt mask to match the target height and width. If mask is None, returns a mask of all ones with the target height and width. If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. Returns: - torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width). + torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). 
""" + if mask is None: - return torch.ones((1, 1, target_height, target_width), dtype=torch.bool) + return torch.ones((1, 1, target_height, target_width), dtype=dtype) + + mask = to_standard_float_mask(mask, out_dtype=dtype) tf = torchvision.transforms.Resize( (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST ) - if len(mask.shape) != 3 or mask.shape[0] != 1: - raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).") - # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) @@ -416,6 +417,7 @@ def _concat_regional_text_embeddings( masks: Optional[list[Optional[torch.Tensor]]], latent_height: int, latent_width: int, + dtype: torch.dtype, ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" if masks is None: @@ -465,7 +467,9 @@ def _concat_regional_text_embeddings( start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] ) ) - processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width)) + processed_masks.append( + self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) + ) cur_text_embedding_len += text_embedding_info.embeds.shape[1] @@ -524,12 +528,14 @@ def get_conditioning_data( masks=cond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, + dtype=unet.dtype, ) uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( text_conditionings=uncond_text_embeddings, masks=uncond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, + dtype=unet.dtype, ) conditioning_data = TextConditioningData( diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index de4887e20d1..b3588f204df 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -90,8 +90,6 @@ def invoke(self, context: InvocationContext) -> ExtractMasksAndPromptsOutput: for pair in self.prompt_color_pairs: # TODO(ryand): Make this work for both RGB and RGBA images. mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) - # Add explicit channel dimension. - mask = mask.unsqueeze(0) mask_tensor_name = context.tensors.save(mask) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name))) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 95f81b1f93e..85331013d5f 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -46,7 +46,6 @@ def _prepare_spatial_masks( for batch_sample_regions in regions: batch_sample_masks_by_seq_len.append({}) - # Convert the bool masks to float masks so that max pooling can be applied. batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype) # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. 
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 8ba988a0ebc..4d95cb8f0df 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -360,7 +360,7 @@ def _apply_standard_conditioning( # Create a dummy mask and range for text conditioning that doesn't have region masks. _, _, h, w = x.shape r = TextConditioningRegions( - masks=torch.ones((1, 1, h, w), dtype=torch.bool), + masks=torch.ones((1, 1, h, w), dtype=x.dtype), ranges=[Range(start=0, end=c.embeds.shape[1])], ) regions.append(r) diff --git a/invokeai/backend/util/mask.py b/invokeai/backend/util/mask.py new file mode 100644 index 00000000000..45aa32061c2 --- /dev/null +++ b/invokeai/backend/util/mask.py @@ -0,0 +1,53 @@ +import torch + + +def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor: + """Standardize the dimensions of a mask tensor. + + Args: + mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w). + + Returns: + torch.Tensor: The output mask tensor. The shape is (1, h, w). + """ + # Get the mask height and width. + if mask.ndim == 2: + mask = mask.unsqueeze(0) + elif mask.ndim == 3 and mask.shape[0] == 1: + pass + else: + raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).") + + return mask + + +def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor: + """Standardize the format of a mask tensor. + + Args: + mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w) + or (h, w). + + out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type. + + Returns: + torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0 + or 1.0. + """ + + if not out_dtype.is_floating_point: + raise ValueError(f"out_dtype must be a float type, but got {out_dtype}") + + mask = to_standard_mask_dim(mask) + mask = mask.to(out_dtype) + + # Set masked regions to 1.0. 
+ if mask.dtype == torch.bool: + mask = mask.to(out_dtype) + else: + mask = mask.to(out_dtype) + mask_region = mask > 0.5 + mask[mask_region] = 1.0 + mask[~mask_region] = 0.0 + + return mask diff --git a/tests/backend/util/test_mask.py b/tests/backend/util/test_mask.py new file mode 100644 index 00000000000..96d3aab07f9 --- /dev/null +++ b/tests/backend/util/test_mask.py @@ -0,0 +1,88 @@ +import pytest +import torch + +from invokeai.backend.util.mask import to_standard_float_mask + + +def test_to_standard_float_mask_wrong_ndim(): + with pytest.raises(ValueError): + to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32) + + +def test_to_standard_float_mask_wrong_shape(): + with pytest.raises(ValueError): + to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32) + + +def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor): + """Helper function to check the result of `to_standard_float_mask()`.""" + assert mask.shape == expected_mask.shape + assert mask.dtype == expected_mask.dtype + assert torch.allclose(mask, expected_mask) + + +def test_to_standard_float_mask_ndim_2(): + """Test the case where the input mask has shape (h, w).""" + mask = torch.zeros((3, 2), dtype=torch.float32) + mask[0, 0] = 1.0 + mask[1, 1] = 1.0 + + expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32) + expected_mask[0, 0, 0] = 1.0 + expected_mask[0, 1, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) + + +def test_to_standard_float_mask_ndim_3(): + """Test the case where the input mask has shape (1, h, w).""" + mask = torch.zeros((1, 3, 2), dtype=torch.float32) + mask[0, 0, 0] = 1.0 + mask[0, 1, 1] = 1.0 + + expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32) + expected_mask[0, 0, 0] = 1.0 + expected_mask[0, 1, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) + + +@pytest.mark.parametrize( + "out_dtype", + [torch.float32, torch.float16], +) +def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype): + """Test the case where the input mask has dtype bool.""" + mask = torch.zeros((3, 2), dtype=torch.bool) + mask[0, 0] = True + mask[1, 1] = True + + expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype) + expected_mask[0, 0, 0] = 1.0 + expected_mask[0, 1, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) + + +@pytest.mark.parametrize( + "out_dtype", + [torch.float32, torch.float16], +) +def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype): + """Test the case where the input mask has type float (but not all values are 0.0 or 1.0).""" + mask = torch.zeros((3, 2), dtype=torch.float32) + mask[0, 0] = 0.1 # Should be converted to 0.0 + mask[0, 1] = 0.9 # Should be converted to 1.0 + + expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype) + expected_mask[0, 0, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) From 3e61d5f1c89e306e57ef30cff95ec6f1b27e75d7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 9 Apr 2024 20:28:38 +1000 Subject: [PATCH 21/21] Revert "feat(nodes): add prompt region from image nodes" This reverts commit 3a531c50976beae6d3b544d649e48d934d8996ce. 
--- invokeai/app/invocations/mask.py | 85 +------------------------------- 1 file changed, 2 insertions(+), 83 deletions(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index b3588f204df..2d414ac2bda 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -1,15 +1,7 @@ -import numpy as np import torch -from pydantic import BaseModel -from invokeai.app.invocations.baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - invocation, - invocation_output, -) -from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata +from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation +from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata from invokeai.app.invocations.primitives import MaskOutput @@ -42,76 +34,3 @@ def invoke(self, context: InvocationContext) -> MaskOutput: width=self.width, height=self.height, ) - - -class PromptColorPair(BaseModel): - prompt: str - color: ColorField - - -class PromptMaskPair(BaseModel): - prompt: str - mask: TensorField - - -default_prompt_color_pairs = [ - PromptColorPair(prompt="Strawberries", color=ColorField(r=200, g=0, b=0, a=255)), - PromptColorPair(prompt="Frog", color=ColorField(r=0, g=200, b=0, a=255)), - PromptColorPair(prompt="Banana", color=ColorField(r=0, g=0, b=200, a=255)), - PromptColorPair(prompt="A gnome", color=ColorField(r=215, g=0, b=255, a=255)), -] - - -@invocation_output("extract_masks_and_prompts_output") -class ExtractMasksAndPromptsOutput(BaseInvocationOutput): - prompt_mask_pairs: list[PromptMaskPair] = OutputField(description="List of prompts and their corresponding masks.") - - -@invocation( - "extract_masks_and_prompts", - title="Extract Masks and Prompts", - tags=["conditioning"], - category="conditioning", - version="1.0.0", -) -class ExtractMasksAndPromptsInvocation(BaseInvocation): - """Extract masks and prompts from a segmented mask image and prompt-to-color map.""" - - prompt_color_pairs: list[PromptColorPair] = InputField( - default=default_prompt_color_pairs, description="List of prompts and their corresponding colors." - ) - image: ImageField = InputField(description="Mask to apply to the prompts.") - - def invoke(self, context: InvocationContext) -> ExtractMasksAndPromptsOutput: - prompt_mask_pairs: list[PromptMaskPair] = [] - image = context.images.get_pil(self.image.image_name) - image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8)) - - for pair in self.prompt_color_pairs: - # TODO(ryand): Make this work for both RGB and RGBA images. 
- mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) - mask_tensor_name = context.tensors.save(mask) - prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name))) - return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) - - -@invocation_output("split_mask_prompt_pair_output") -class SplitMaskPromptPairOutput(BaseInvocationOutput): - prompt: str = OutputField() - mask: TensorField = OutputField() - - -@invocation( - "split_mask_prompt_pair", - title="Split Mask-Prompt pair", - tags=["conditioning"], - category="conditioning", - version="1.0.0", -) -class SplitMaskPromptPair(BaseInvocation): - """Extract masks and prompts from a segmented mask image and prompt-to-color map.""" - - prompt_mask_pair: PromptMaskPair = InputField() - - def invoke(self, context: InvocationContext) -> SplitMaskPromptPairOutput: - return SplitMaskPromptPairOutput(mask=self.prompt_mask_pair.mask, prompt=self.prompt_mask_pair.prompt)
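
Editorial appendix — a minimal usage sketch (not taken from any commit above) of the to_standard_float_mask(...) utility introduced in invokeai/backend/util/mask.py by PATCH 20/21, mirroring the bool-to-float case exercised in tests/backend/util/test_mask.py. The mask size, the rectangle coordinates, and the float16 target dtype are illustrative assumptions rather than values from the patches.

    import torch

    from invokeai.backend.util.mask import to_standard_float_mask

    # Build an (h, w) bool mask with a small "included" rectangular region.
    # (Illustrative size/coordinates; any (h, w) or (1, h, w) mask is accepted.)
    mask = torch.zeros((64, 64), dtype=torch.bool)
    mask[8:24, 16:48] = True

    # Standardize to the float format consumed by the regional prompting path:
    # shape (1, h, w), values exactly 0.0 or 1.0, in the requested dtype
    # (the denoise code passes the UNet's dtype for this argument).
    float_mask = to_standard_float_mask(mask, out_dtype=torch.float16)

    assert float_mask.shape == (1, 64, 64)
    assert float_mask.dtype == torch.float16
    assert torch.all((float_mask == 0.0) | (float_mask == 1.0))

As in the unit tests above, a float input mask with values that are not exactly 0.0 or 1.0 would be thresholded at 0.5 by the same call, so callers do not need to binarize masks themselves before handing them to the regional conditioning code.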