diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index c23dd3d908e..92012691ea2 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,7 +5,15 @@ from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ConditioningField, + FieldDescriptions, + Input, + InputField, + OutputField, + TensorField, + UIComponent, +) from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list @@ -36,7 +44,7 @@ title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.1.1", + version="1.2.0", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -51,6 +59,9 @@ class CompelInvocation(BaseInvocation): description=FieldDescriptions.clip, input=Input.Connection, ) + mask: Optional[TensorField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -117,8 +128,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ) conditioning_name = context.conditioning.save(conditioning_data) - - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + ) + ) class SDXLPromptInvocationBase: @@ -232,7 +247,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.1.1", + version="1.2.0", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -255,6 +270,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): target_height: int = InputField(default=1024, description="") clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") + mask: Optional[TensorField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." 
+ ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -317,7 +335,12 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + ) + ) @invocation( diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index d90c71a32de..0fa0216f1c7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -203,6 +203,12 @@ class DenoiseMaskField(BaseModel): gradient: bool = Field(default=False, description="Used for gradient inpainting") +class TensorField(BaseModel): + """A tensor primitive field.""" + + tensor_name: str = Field(description="The name of a tensor.") + + class LatentsField(BaseModel): """A latents tensor primitive field""" @@ -226,7 +232,11 @@ class ConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") - # endregion + mask: Optional[TensorField] = Field( + default=None, + description="The mask associated with this conditioning tensor. Excluded regions should be set to False, " + "included regions should be set to True.", + ) class MetadataField(RootModel[dict[str, Any]]): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 449d0135049..d5babe42cc1 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - +import inspect import math from contextlib import ExitStack from functools import singledispatchmethod @@ -9,6 +9,7 @@ import numpy as np import numpy.typing as npt import torch +import torchvision import torchvision.transforms as T from diffusers import AutoencoderKL, AutoencoderTiny from diffusers.configuration_utils import ConfigMixin @@ -52,7 +53,15 @@ from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, + IPAdapterConditioningInfo, + Range, + SDXLConditioningInfo, + TextConditioningData, + TextConditioningRegions, +) +from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.silence_warnings import SilenceWarnings from ...backend.stable_diffusion.diffusers_pipeline import ( @@ -275,10 +284,10 @@ def get_scheduler( class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" - positive_conditioning: ConditioningField = InputField( + positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 ) - negative_conditioning: ConditioningField = InputField( + negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1 ) noise: Optional[LatentsField] = InputField( @@ -356,34 +365,187 @@ def ge_one(cls, v: Union[List[float], float]) -> 
Union[List[float], float]: raise ValueError("cfg_scale must be greater than 1") return v + def _get_text_embeddings_and_masks( + self, + cond_list: list[ConditioningField], + context: InvocationContext, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]: + """Get the text embeddings and masks from the input conditioning fields.""" + text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] + text_embeddings_masks: list[Optional[torch.Tensor]] = [] + for cond in cond_list: + cond_data = context.conditioning.load(cond.conditioning_name) + text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) + + mask = cond.mask + if mask is not None: + mask = context.tensors.load(mask.tensor_name) + text_embeddings_masks.append(mask) + + return text_embeddings, text_embeddings_masks + + def _preprocess_regional_prompt_mask( + self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + + Returns: + torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). + """ + + if mask is None: + return torch.ones((1, 1, target_height, target_width), dtype=dtype) + + mask = to_standard_float_mask(mask, out_dtype=dtype) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + return resized_mask + + def _concat_regional_text_embeddings( + self, + text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], + masks: Optional[list[Optional[torch.Tensor]]], + latent_height: int, + latent_width: int, + dtype: torch.dtype, + ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: + """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" + if masks is None: + masks = [None] * len(text_conditionings) + assert len(text_conditionings) == len(masks) + + is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo + + all_masks_are_none = all(mask is None for mask in masks) + + text_embedding = [] + pooled_embedding = None + add_time_ids = None + cur_text_embedding_len = 0 + processed_masks = [] + embedding_ranges = [] + extra_conditioning = None + + for prompt_idx, text_embedding_info in enumerate(text_conditionings): + mask = masks[prompt_idx] + if ( + text_embedding_info.extra_conditioning is not None + and text_embedding_info.extra_conditioning.wants_cross_attention_control + ): + extra_conditioning = text_embedding_info.extra_conditioning + + if is_sdxl: + # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for + # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain + # global prompt information. In an ideal case, there should be exactly one global prompt without a + # mask, but we don't enforce this. 
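# Illustrative sketch (not part of the patch): the selection rule described above, in isolation.
# `global_info`, `ball_info`, and `ball_mask` are hypothetical stand-ins for SDXLConditioningInfo
# objects and a region mask.
#
#   pooled_embedding = None
#   for info, mask in [(ball_info, ball_mask), (global_info, None)]:
#       if pooled_embedding is None or mask is None:
#           pooled_embedding = info.pooled_embeds
#
# The first iteration fills in a value; any later prompt *without* a mask overwrites it, so an
# unmasked "global" prompt wins regardless of its position in the list.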
+ + # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a + # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use + # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single + # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a + # pretty major breaking change to a popular node, so for now we use this hack. + if pooled_embedding is None or mask is None: + pooled_embedding = text_embedding_info.pooled_embeds + if add_time_ids is None or mask is None: + add_time_ids = text_embedding_info.add_time_ids + + text_embedding.append(text_embedding_info.embeds) + if not all_masks_are_none: + embedding_ranges.append( + Range( + start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] + ) + ) + processed_masks.append( + self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) + ) + + cur_text_embedding_len += text_embedding_info.embeds.shape[1] + + text_embedding = torch.cat(text_embedding, dim=1) + assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len + + regions = None + if not all_masks_are_none: + regions = TextConditioningRegions( + masks=torch.cat(processed_masks, dim=1), + ranges=embedding_ranges, + ) + + if extra_conditioning is not None and len(text_conditionings) > 1: + raise ValueError( + "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple " + "prompts." + ) + + if is_sdxl: + return SDXLConditioningInfo( + embeds=text_embedding, + extra_conditioning=extra_conditioning, + pooled_embeds=pooled_embedding, + add_time_ids=add_time_ids, + ), regions + return BasicConditioningInfo( + embeds=text_embedding, + extra_conditioning=extra_conditioning, + ), regions + def get_conditioning_data( self, context: InvocationContext, - scheduler: Scheduler, unet: UNet2DConditionModel, - seed: int, - ) -> ConditioningData: - positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) + latent_height: int, + latent_width: int, + ) -> TextConditioningData: + # Normalize self.positive_conditioning and self.negative_conditioning to lists. 
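# Illustrative note (not part of the patch): with regional prompting, this node may now receive
# either a single ConditioningField or a list of them, e.g. (hypothetical graph values)
#   positive_conditioning = [background_cond, subject_cond]   # each field may carry its own .mask
#   negative_conditioning = negative_cond                      # a single field with mask=None
# The normalization below gives the rest of the method a single code path for both shapes.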
+ cond_list = self.positive_conditioning + if not isinstance(cond_list, list): + cond_list = [cond_list] + uncond_list = self.negative_conditioning + if not isinstance(uncond_list, list): + uncond_list = [uncond_list] + + cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( + cond_list, context, unet.device, unet.dtype + ) + uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( + uncond_list, context, unet.device, unet.dtype + ) - negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) + cond_text_embedding, cond_regions = self._concat_regional_text_embeddings( + text_conditionings=cond_text_embeddings, + masks=cond_text_embedding_masks, + latent_height=latent_height, + latent_width=latent_width, + dtype=unet.dtype, + ) + uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( + text_conditionings=uncond_text_embeddings, + masks=uncond_text_embedding_masks, + latent_height=latent_height, + latent_width=latent_width, + dtype=unet.dtype, + ) - conditioning_data = ConditioningData( - unconditioned_embeddings=uc, - text_embeddings=c, + conditioning_data = TextConditioningData( + uncond_text=uncond_text_embedding, + cond_text=cond_text_embedding, + uncond_regions=uncond_regions, + cond_regions=cond_regions, guidance_scale=self.cfg_scale, guidance_rescale_multiplier=self.cfg_rescale_multiplier, ) - - conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME - scheduler, - # for ddim scheduler - eta=0.0, # ddim_eta - # for ancestral and sde schedulers - # flip all bits to have noise different from initial - generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF), - ) return conditioning_data def create_pipeline( @@ -488,7 +650,6 @@ def prep_ip_adapter_data( self, context: InvocationContext, ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], - conditioning_data: ConditioningData, exit_stack: ExitStack, ) -> Optional[list[IPAdapterData]]: """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings @@ -505,7 +666,6 @@ def prep_ip_adapter_data( return None ip_adapter_data_list = [] - conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( context.models.load(single_ip_adapter.ip_adapter_model) @@ -528,16 +688,13 @@ def prep_ip_adapter_data( single_ipa_images, image_encoder_model ) - conditioning_data.ip_adapter_conditioning.append( - IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds) - ) - ip_adapter_data_list.append( IPAdapterData( ip_adapter_model=ip_adapter_model, weight=single_ip_adapter.weight, begin_step_percent=single_ip_adapter.begin_step_percent, end_step_percent=single_ip_adapter.end_step_percent, + ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), ) ) @@ -627,6 +784,7 @@ def init_scheduler( steps: int, denoising_start: float, denoising_end: float, + seed: int, ) -> Tuple[int, List[int], int]: assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): @@ -655,7 +813,15 @@ def init_scheduler( timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] num_inference_steps = len(timesteps) // scheduler.order - return num_inference_steps, timesteps, init_timestep + 
scheduler_step_kwargs = {} + scheduler_step_signature = inspect.signature(scheduler.step) + if "generator" in scheduler_step_signature.parameters: + # At some point, someone decided that schedulers that accept a generator should use the original seed with + # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for + # reproducibility. + scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)} + + return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs def prep_inpaint_mask( self, context: InvocationContext, latents: torch.Tensor @@ -749,7 +915,11 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) + + _, _, latent_height, latent_width = latents.shape + conditioning_data = self.get_conditioning_data( + context=context, unet=unet, latent_height=latent_height, latent_width=latent_width + ) controlnet_data = self.prep_control_data( context=context, @@ -763,16 +933,16 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: ip_adapter_data = self.prep_ip_adapter_data( context=context, ip_adapter=self.ip_adapter, - conditioning_data=conditioning_data, exit_stack=exit_stack, ) - num_inference_steps, timesteps, init_timestep = self.init_scheduler( + num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( scheduler, device=unet.device, steps=self.steps, denoising_start=self.denoising_start, denoising_end=self.denoising_end, + seed=seed, ) result_latents = pipeline.latents_from_embeddings( @@ -785,6 +955,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: masked_latents=masked_latents, gradient_mask=gradient_mask, num_inference_steps=num_inference_steps, + scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, control_data=controlnet_data, ip_adapter_data=ip_adapter_data, diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py new file mode 100644 index 00000000000..2d414ac2bda --- /dev/null +++ b/invokeai/app/invocations/mask.py @@ -0,0 +1,36 @@ +import torch + +from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation +from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata +from invokeai.app.invocations.primitives import MaskOutput + + +@invocation( + "rectangle_mask", + title="Create Rectangle Mask", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class RectangleMaskInvocation(BaseInvocation, WithMetadata): + """Create a rectangular mask.""" + + height: int = InputField(description="The height of the entire mask.") + width: int = InputField(description="The width of the entire mask.") + y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).") + x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).") + rectangle_height: int = InputField(description="The height of the rectangular masked region.") + rectangle_width: int = InputField(description="The width of the rectangular masked region.") + + def invoke(self, context: InvocationContext) -> MaskOutput: + mask = torch.zeros((1, self.height, self.width), dtype=torch.bool) + mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = ( + True + ) + + 
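# Illustrative sketch (not part of the patch): the resulting mask for hypothetical inputs
# height=512, width=512, y_top=64, x_left=64, rectangle_height=128, rectangle_width=128.
#
#   import torch
#   mask = torch.zeros((1, 512, 512), dtype=torch.bool)
#   mask[:, 64:192, 64:192] = True
#   mask.sum()   # tensor(16384) == 128 * 128 pixels inside the rectangle
#
# Everything outside the rectangle stays False, i.e. excluded from the prompt's region.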
mask_tensor_name = context.tensors.save(mask) + return MaskOutput( + mask=TensorField(tensor_name=mask_tensor_name), + width=self.width, + height=self.height, + ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 6a8e4e4531d..28f72fb377a 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -15,6 +15,7 @@ InputField, LatentsField, OutputField, + TensorField, UIComponent, ) from invokeai.app.services.images.images_common import ImageDTO @@ -405,9 +406,19 @@ def invoke(self, context: InvocationContext) -> ColorOutput: # endregion + # region Conditioning +@invocation_output("mask_output") +class MaskOutput(BaseInvocationOutput): + """A torch mask tensor.""" + + mask: TensorField = OutputField(description="The mask.") + width: int = OutputField(description="The width of the mask in pixels.") + height: int = OutputField(description="The height of the mask in pixels.") + + @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py deleted file mode 100644 index 195cb12d1b8..00000000000 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ /dev/null @@ -1,182 +0,0 @@ -# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) -# and modified as needed - -# tencent-ailab comment: -# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0 - -from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights - - -# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict -# loading. -class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module): - def __init__(self): - DiffusersAttnProcessor2_0.__init__(self) - nn.Module.__init__(self) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the - ip_adapter_image_prompt_embeds parameter. - """ - return DiffusersAttnProcessor2_0.__call__( - self, attn, hidden_states, encoder_hidden_states, attention_mask, temb - ) - - -class IPAttnProcessor2_0(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - """ - - def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - assert len(weights) == len(scales) - - self._weights = weights - self._scales = scales - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Apply IP-Adapter attention. 
- - Args: - ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings. - Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len). - """ - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case, - # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here. - assert ip_adapter_image_prompt_embeds is not None - assert len(ip_adapter_image_prompt_embeds) == len(self._weights) - - for ipa_embed, ipa_weights, scale in zip( - ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True - ): - # The batch dimensions should match. - assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. 
- assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] - - ip_hidden_states = ipa_embed - - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - - hidden_states = hidden_states + scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 3370b71f027..278a53eb0f1 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -22,9 +22,12 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + IPAdapterConditioningInfo, + TextConditioningData, +) from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.devices import normalize_device @@ -151,10 +154,11 @@ class ControlNetData: @dataclass class IPAdapterData: - ip_adapter_model: IPAdapter = Field(default=None) - # TODO: change to polymorphic so can do different weights per step (once implemented...) + ip_adapter_model: IPAdapter + ip_adapter_conditioning: IPAdapterConditioningInfo + + # Either a single weight applied to all steps, or a list of weights for each step. 
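# Illustrative (not part of the patch), with hypothetical values: weight=0.6 applies a constant
# scale, while weight=[0.0] * 4 + [0.8] * 16 fades the IP-Adapter in after the first four of
# twenty steps.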
weight: Union[float, List[float]] = Field(default=1.0) - # weight: float = Field(default=1.0) begin_step_percent: float = Field(default=0.0) end_step_percent: float = Field(default=1.0) @@ -295,7 +299,8 @@ def latents_from_embeddings( self, latents: torch.Tensor, num_inference_steps: int, - conditioning_data: ConditioningData, + scheduler_step_kwargs: dict[str, Any], + conditioning_data: TextConditioningData, *, noise: Optional[torch.Tensor], timesteps: torch.Tensor, @@ -352,6 +357,7 @@ def latents_from_embeddings( latents, timesteps, conditioning_data, + scheduler_step_kwargs=scheduler_step_kwargs, additional_guidance=additional_guidance, control_data=control_data, ip_adapter_data=ip_adapter_data, @@ -377,7 +383,8 @@ def generate_latents_from_embeddings( self, latents: torch.Tensor, timesteps, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, + scheduler_step_kwargs: dict[str, Any], *, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, @@ -394,22 +401,35 @@ def generate_latents_from_embeddings( if timesteps.shape[0] == 0: return latents - ip_adapter_unet_patcher = None - extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + extra_conditioning_info = conditioning_data.cond_text.extra_conditioning + use_cross_attention_control = ( + extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control + ) + use_ip_adapter = ip_adapter_data is not None + use_regional_prompting = ( + conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None + ) + if use_cross_attention_control and use_ip_adapter: + raise ValueError( + "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously." + ) + if use_cross_attention_control and use_regional_prompting: + raise ValueError( + "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously." + ) + + unet_attention_patcher = None + self.use_ip_adapter = use_ip_adapter + attn_ctx = nullcontext() + if use_cross_attention_control: attn_ctx = self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, extra_conditioning_info=extra_conditioning_info, ) - self.use_ip_adapter = False - elif ip_adapter_data is not None: - # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? - # As it is now, the IP-Adapter will silently be skipped. 
- ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) - attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) - self.use_ip_adapter = True - else: - attn_ctx = nullcontext() + if use_ip_adapter or use_regional_prompting: + ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None + unet_attention_patcher = UNetAttentionPatcher(ip_adapters) + attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) with attn_ctx: if callback is not None: @@ -432,11 +452,12 @@ def generate_latents_from_embeddings( conditioning_data, step_index=i, total_step_count=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, additional_guidance=additional_guidance, control_data=control_data, ip_adapter_data=ip_adapter_data, t2i_adapter_data=t2i_adapter_data, - ip_adapter_unet_patcher=ip_adapter_unet_patcher, + unet_attention_patcher=unet_attention_patcher, ) latents = step_output.prev_sample predicted_original = getattr(step_output, "pred_original_sample", None) @@ -460,14 +481,15 @@ def step( self, t: torch.Tensor, latents: torch.Tensor, - conditioning_data: ConditioningData, + conditioning_data: TextConditioningData, step_index: int, total_step_count: int, + scheduler_step_kwargs: dict[str, Any], additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - ip_adapter_unet_patcher: Optional[UNetPatcher] = None, + unet_attention_patcher: Optional[UNetAttentionPatcher] = None, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] @@ -494,10 +516,10 @@ def step( ) if step_index >= first_adapter_step and step_index <= last_adapter_step: # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range. - ip_adapter_unet_patcher.set_scale(i, weight) + unet_attention_patcher.set_scale(i, weight) else: # Otherwise, set the IP-Adapter's scale to 0, so it has no effect. - ip_adapter_unet_patcher.set_scale(i, 0.0) + unet_attention_patcher.set_scale(i, 0.0) # Handle ControlNet(s) down_block_additional_residuals = None @@ -541,12 +563,17 @@ def step( down_intrablock_additional_residuals = accum_adapter_state + ip_adapter_conditioning = None + if ip_adapter_data is not None: + ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] + uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( sample=latent_model_input, timestep=t, # TODO: debug how handled batched and non batched timesteps step_index=step_index, total_step_count=total_step_count, conditioning_data=conditioning_data, + ip_adapter_conditioning=ip_adapter_conditioning, down_block_additional_residuals=down_block_additional_residuals, # for ControlNet mid_block_additional_residual=mid_block_additional_residual, # for ControlNet down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter @@ -566,7 +593,7 @@ def step( ) # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) + step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs) # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again. 
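# Illustrative sketch (not part of the patch): what `scheduler_step_kwargs`, built by
# `init_scheduler()`, typically contains when it reaches the `scheduler.step()` call above.
# The seed value is hypothetical.
#
#   import inspect, torch
#   seed = 1234
#   if "generator" in inspect.signature(scheduler.step).parameters:
#       scheduler_step_kwargs = {
#           # 1234 ^ 0xFFFFFFFF == 4294966061: the historical "all bits flipped" seed convention.
#           "generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF),
#       }
#   else:
#       scheduler_step_kwargs = {}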
for guidance in additional_guidance: diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 7a7f4624c82..6ef6d68fca7 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -1,7 +1,5 @@ -import dataclasses -import inspect -from dataclasses import dataclass, field -from typing import Any, List, Optional, Union +from dataclasses import dataclass +from typing import List, Optional, Union import torch @@ -10,6 +8,10 @@ @dataclass class ExtraConditioningInfo: + """Extra conditioning information produced by Compel. + This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel). + """ + tokens_count_including_eos_bos: int cross_attention_control_args: Optional[Arguments] = None @@ -20,6 +22,8 @@ def wants_cross_attention_control(self): @dataclass class BasicConditioningInfo: + """SD 1/2 text conditioning information produced by Compel.""" + embeds: torch.Tensor extra_conditioning: Optional[ExtraConditioningInfo] @@ -35,6 +39,8 @@ class ConditioningFieldData: @dataclass class SDXLConditioningInfo(BasicConditioningInfo): + """SDXL text conditioning information produced by Compel.""" + pooled_embeds: torch.Tensor add_time_ids: torch.Tensor @@ -57,37 +63,52 @@ class IPAdapterConditioningInfo: @dataclass -class ConditioningData: - unconditioned_embeddings: BasicConditioningInfo - text_embeddings: BasicConditioningInfo - """ - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). - Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate - images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - """ - guidance_scale: Union[float, List[float]] - """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 . - ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) - """ - guidance_rescale_multiplier: float = 0 - scheduler_args: dict[str, Any] = field(default_factory=dict) - - ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None - - @property - def dtype(self): - return self.text_embeddings.dtype - - def add_scheduler_args_if_applicable(self, scheduler, **kwargs): - scheduler_args = dict(self.scheduler_args) - step_method = inspect.signature(scheduler.step) - for name, value in kwargs.items(): - try: - step_method.bind_partial(**{name: value}) - except TypeError: - # FIXME: don't silently discard arguments - pass # debug("%s does not accept argument named %r", scheduler, name) - else: - scheduler_args[name] = value - return dataclasses.replace(self, scheduler_args=scheduler_args) +class Range: + start: int + end: int + + +class TextConditioningRegions: + def __init__( + self, + masks: torch.Tensor, + ranges: list[Range], + ): + # A binary mask indicating the regions of the image that the prompt should be applied to. + # Shape: (1, num_prompts, height, width) + # Dtype: torch.bool + self.masks = masks + + # A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to. + # ranges[i] contains the embedding range for the i'th prompt / mask. 
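# Illustrative example (not part of the patch), with hypothetical sizes: for two prompts whose
# embeddings are 77 and 154 tokens long, the concatenated embedding has seq_len 231 and
#   ranges = [Range(start=0, end=77), Range(start=77, end=231)]
#   masks.shape == (1, 2, latent_height, latent_width)
# so masks[:, i] marks the image region governed by the tokens in ranges[i].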
+ self.ranges = ranges + + assert self.masks.shape[1] == len(self.ranges) + + +class TextConditioningData: + def __init__( + self, + uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo], + cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo], + uncond_regions: Optional[TextConditioningRegions], + cond_regions: Optional[TextConditioningRegions], + guidance_scale: Union[float, List[float]], + guidance_rescale_multiplier: float = 0, + ): + self.uncond_text = uncond_text + self.cond_text = cond_text + self.uncond_regions = uncond_regions + self.cond_regions = cond_regions + # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). + # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate + # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + self.guidance_scale = guidance_scale + # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. + # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + self.guidance_rescale_multiplier = guidance_rescale_multiplier + + def is_sdxl(self): + assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) + return isinstance(self.cond_text, SDXLConditioningInfo) diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py new file mode 100644 index 00000000000..667fcd9a645 --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -0,0 +1,199 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention, AttnProcessor2_0 + +from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData + + +class CustomAttnProcessor2_0(AttnProcessor2_0): + """A custom implementation of AttnProcessor2_0 that supports additional Invoke features. + This implementation is based on + https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204 + Supported custom features: + - IP-Adapter + - Regional prompt attention + """ + + def __init__( + self, + ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None, + ip_adapter_scales: Optional[list[float]] = None, + ): + """Initialize a CustomAttnProcessor2_0. + Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are + layer-specific are passed to __init__(). + Args: + ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights + for the i'th IP-Adapter. + ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for + the i'th IP-Adapter. 
+ """ + super().__init__() + + self._ip_adapter_weights = ip_adapter_weights + self._ip_adapter_scales = ip_adapter_scales + + assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None) + if self._ip_adapter_weights is not None: + assert len(ip_adapter_weights) == len(ip_adapter_scales) + + def _is_ip_adapter_enabled(self) -> bool: + return self._ip_adapter_weights is not None + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + # For regional prompting: + regional_prompt_data: Optional[RegionalPromptData] = None, + percent_through: Optional[torch.FloatTensor] = None, + # For IP-Adapter: + ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None, + ) -> torch.FloatTensor: + """Apply attention. + Args: + regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to + apply regional prompt masking. + ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch. + ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each + tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len). + """ + # If true, we are doing cross-attention, if false we are doing self-attention. + is_cross_attention = encoder_hidden_states is not None + + # Start unmodified block from AttnProcessor2_0. + # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # End unmodified block from AttnProcessor2_0. + + # Handle regional prompt attention masks. + if regional_prompt_data is not None and is_cross_attention: + assert percent_through is not None + _, query_seq_len, _ = hidden_states.shape + prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( + query_seq_len=query_seq_len, key_seq_len=sequence_length + ) + + if attention_mask is None: + attention_mask = prompt_region_attention_mask + else: + attention_mask = prompt_region_attention_mask + attention_mask + + # Start unmodified block from AttnProcessor2_0. 
+ # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # End unmodified block from AttnProcessor2_0. + + # Apply IP-Adapter conditioning. + if is_cross_attention and self._is_ip_adapter_enabled(): + if self._is_ip_adapter_enabled(): + assert ip_adapter_image_prompt_embeds is not None + for ipa_embed, ipa_weights, scale in zip( + ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True + ): + # The batch dimensions should match. + assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] + # The token_len dimensions should match. + assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] + + ip_hidden_states = ipa_embed + + # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) + + ip_key = ipa_weights.to_k_ip(ip_hidden_states) + ip_value = ipa_weights.to_v_ip(ip_hidden_states) + + # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) + + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + + hidden_states = hidden_states + scale * ip_hidden_states + else: + # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in. + assert ip_adapter_image_prompt_embeds is None + + # Start unmodified block from AttnProcessor2_0. 
+ # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py new file mode 100644 index 00000000000..85331013d5f --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -0,0 +1,102 @@ +import torch +import torch.nn.functional as F + +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + TextConditioningRegions, +) + + +class RegionalPromptData: + """A class to manage the prompt data for regional conditioning.""" + + def __init__( + self, + regions: list[TextConditioningRegions], + device: torch.device, + dtype: torch.dtype, + max_downscale_factor: int = 8, + ): + """Initialize a `RegionalPromptData` object. + Args: + regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the + batch. + device (torch.device): The device to use for the attention masks. + dtype (torch.dtype): The data type to use for the attention masks. + max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor + in steps of 2x. + """ + self._regions = regions + self._device = device + self._dtype = dtype + # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query + # sequence length of s. + self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks( + regions, max_downscale_factor + ) + self._negative_cross_attn_mask_score = -10000.0 + + def _prepare_spatial_masks( + self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 + ) -> list[dict[int, torch.Tensor]]: + """Prepare the spatial masks for all downscaling factors.""" + # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length + # of s. + batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] + + for batch_sample_regions in regions: + batch_sample_masks_by_seq_len.append({}) + + batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype) + + # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. + downscale_factor = 1 + while downscale_factor <= max_downscale_factor: + b, _num_prompts, h, w = batch_sample_masks.shape + assert b == 1 + query_seq_len = h * w + + batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks + + downscale_factor *= 2 + if downscale_factor <= max_downscale_factor: + # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt + # regions to be lost entirely. + # TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g. + # nearest interpolation), and could potentially use a weighted mask rather than a binary mask. 
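# Illustrative sketch (not part of the patch): why max pooling is preferred here. A 1-pixel
# region survives 2x max pooling but is diluted by average pooling:
#
#   import torch
#   import torch.nn.functional as F
#   m = torch.zeros(1, 1, 8, 8)
#   m[0, 0, 3, 3] = 1.0
#   F.max_pool2d(m, kernel_size=2, stride=2).max()   # tensor(1.) - the region is still "on"
#   F.avg_pool2d(m, kernel_size=2, stride=2).max()   # tensor(0.2500) - fading toward zero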
+ batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2) + + return batch_sample_masks_by_seq_len + + def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor: + """Get the cross-attention mask for the given query sequence length. + Args: + query_seq_len: The length of the flattened spatial features at the current downscaling level. + key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention + layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure. + Returns: + torch.Tensor: The cross-attention score mask. + shape: (batch_size, query_seq_len, key_seq_len). + dtype: float + """ + batch_size = len(self._spatial_masks_by_seq_len) + batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)] + + # Create an empty attention mask with the correct shape. + attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device) + + for batch_idx in range(batch_size): + batch_sample_spatial_masks = batch_spatial_masks[batch_idx] + batch_sample_regions = self._regions[batch_idx] + + # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). + _, num_prompts, _, _ = batch_sample_spatial_masks.shape + batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1)) + + for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): + batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone() + batch_sample_query_mask = batch_sample_query_scores > 0.5 + batch_sample_query_scores[batch_sample_query_mask] = 0.0 + batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score + attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores + + return attn_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f55876623cd..4d95cb8f0df 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -10,10 +10,13 @@ from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - ConditioningData, ExtraConditioningInfo, - SDXLConditioningInfo, + IPAdapterConditioningInfo, + Range, + TextConditioningData, + TextConditioningRegions, ) +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from .cross_attention_control import ( CrossAttentionType, @@ -90,7 +93,7 @@ def do_controlnet_step( timestep: torch.Tensor, step_index: int, total_step_count: int, - conditioning_data, + conditioning_data: TextConditioningData, ): down_block_res_samples, mid_block_res_sample = None, None @@ -123,28 +126,28 @@ def do_controlnet_step( added_cond_kwargs = None if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned - if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: + if conditioning_data.is_sdxl(): added_cond_kwargs = { - "text_embeds": conditioning_data.text_embeddings.pooled_embeds, - "time_ids": conditioning_data.text_embeddings.add_time_ids, + "text_embeds": conditioning_data.cond_text.pooled_embeds, + "time_ids": conditioning_data.cond_text.add_time_ids, 
                         }
-                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                    encoder_hidden_states = conditioning_data.cond_text.embeds
                     encoder_attention_mask = None
                 else:
-                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                         added_cond_kwargs = {
                             "text_embeds": torch.cat(
                                 [
                                     # TODO: how to pad? just by zeros? or even truncate?
-                                    conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                    conditioning_data.text_embeddings.pooled_embeds,
+                                    conditioning_data.uncond_text.pooled_embeds,
+                                    conditioning_data.cond_text.pooled_embeds,
                                 ],
                                 dim=0,
                             ),
                             "time_ids": torch.cat(
                                 [
-                                    conditioning_data.unconditioned_embeddings.add_time_ids,
-                                    conditioning_data.text_embeddings.add_time_ids,
+                                    conditioning_data.uncond_text.add_time_ids,
+                                    conditioning_data.cond_text.add_time_ids,
                                 ],
                                 dim=0,
                             ),
@@ -153,8 +156,8 @@ def do_controlnet_step(
                         encoder_hidden_states,
                         encoder_attention_mask,
                     ) = self._concat_conditionings_for_batch(
-                        conditioning_data.unconditioned_embeddings.embeds,
-                        conditioning_data.text_embeddings.embeds,
+                        conditioning_data.uncond_text.embeds,
+                        conditioning_data.cond_text.embeds,
                     )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
@@ -198,16 +201,17 @@ def do_unet_step(
         self,
         sample: torch.Tensor,
         timestep: torch.Tensor,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         step_index: int,
         total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
@@ -223,6 +227,8 @@ def do_unet_step(
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -236,6 +242,8 @@ def do_unet_step(
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -296,7 +304,9 @@ def do_unet_step(
         self,
         x,
         sigma,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -307,40 +317,61 @@ def _apply_standard_conditioning(
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        cross_attention_kwargs = None
-        if conditioning_data.ip_adapter_conditioning is not None:
+        cross_attention_kwargs = {}
+        if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
                 "text_embeds": torch.cat(
                     [
                         # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        conditioning_data.text_embeddings.pooled_embeds,
+                        conditioning_data.uncond_text.pooled_embeds,
+                        conditioning_data.cond_text.pooled_embeds,
                     ],
                     dim=0,
                 ),
                 "time_ids": torch.cat(
                     [
-                        conditioning_data.unconditioned_embeddings.add_time_ids,
-                        conditioning_data.text_embeddings.add_time_ids,
+                        conditioning_data.uncond_text.add_time_ids,
+                        conditioning_data.cond_text.add_time_ids,
                     ],
                     dim=0,
                 ),
             }
 
+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=x.dtype),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
+            conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
         )
         both_results = self.model_forward_callback(
             x_twice,
@@ -360,7 +391,9 @@ def _apply_standard_conditioning_sequentially(
         self,
         x: torch.Tensor,
         sigma,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
@@ -408,36 +441,40 @@ def _apply_standard_conditioning_sequentially(
         # Unconditioned pass
         #####################
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
 
         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         # Prepare cross-attention control kwargs for the unconditioned pass.
         if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
 
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text.add_time_ids,
             }
 
+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
@@ -449,36 +486,41 @@ def _apply_standard_conditioning_sequentially(
         # Conditioned pass
         ###################
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
 
         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         # Prepare cross-attention control kwargs for the conditioned pass.
         if cross_attn_processor_context is not None:
             cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
 
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                "time_ids": conditioning_data.cond_text.add_time_ids,
             }
 
+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.text_embeddings.embeds,
+            conditioning_data.cond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
similarity index 53%
rename from invokeai/backend/ip_adapter/unet_patcher.py
rename to invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index f8c1870f6ee..364ec18da42 100644
--- a/invokeai/backend/ip_adapter/unet_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -1,52 +1,54 @@
 from contextlib import contextmanager
+from typing import Optional
 
 from diffusers.models import UNet2DConditionModel
 
-from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 
 
-class UNetPatcher:
-    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+class UNetAttentionPatcher:
+    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
 
-    def __init__(self, ip_adapters: list[IPAdapter]):
+    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
         self._ip_adapters = ip_adapters
-        self._scales = [1.0] * len(self._ip_adapters)
+        self._ip_adapter_scales = None
+
+        if self._ip_adapters is not None:
+            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
 
     def set_scale(self, idx: int, value: float):
-        self._scales[idx] = value
+        self._ip_adapter_scales[idx] = value
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-        weights into them.
-
+        weights into them (if IP-Adapters are being applied).
         Note that the `unet` param is only used to determine attention block dimensions and naming.
         """
         # Construct a dict of attention processors based on the UNet's architecture.
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
-            if name.endswith("attn1.processor"):
-                attn_procs[name] = AttnProcessor2_0()
+            if name.endswith("attn1.processor") or self._ip_adapters is None:
+                # "attn1" processors do not use IP-Adapters.
+                attn_procs[name] = CustomAttnProcessor2_0()
            else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = IPAttnProcessor2_0(
+                attn_procs[name] = CustomAttnProcessor2_0(
                     [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                    self._scales,
+                    self._ip_adapter_scales,
                 )
         return attn_procs
 
     @contextmanager
     def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
-        """A context manager that patches `unet` with IP-Adapter attention processors."""
-
+        """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
         attn_procs = self._prepare_attention_processors(unet)
-
         orig_attn_processors = unet.attn_processors
         try:
-            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
+            # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
+            # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
             unet.set_attn_processor(attn_procs)
             yield None
         finally:
diff --git a/invokeai/backend/util/mask.py b/invokeai/backend/util/mask.py
new file mode 100644
index 00000000000..45aa32061c2
--- /dev/null
+++ b/invokeai/backend/util/mask.py
@@ -0,0 +1,53 @@
+import torch
+
+
+def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
+    """Standardize the dimensions of a mask tensor.
+
+    Args:
+        mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).
+
+    Returns:
+        torch.Tensor: The output mask tensor. The shape is (1, h, w).
+    """
+    # Get the mask height and width.
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(0)
+    elif mask.ndim == 3 and mask.shape[0] == 1:
+        pass
+    else:
+        raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")
+
+    return mask
+
+
+def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
+    """Standardize the format of a mask tensor.
+
+    Args:
+        mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
+            or (h, w).
+
+        out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.
+
+    Returns:
+        torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
+            or 1.0.
+    """
+
+    if not out_dtype.is_floating_point:
+        raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")
+
+    mask = to_standard_mask_dim(mask)
+    mask = mask.to(out_dtype)
+
+    # Set masked regions to 1.0.
+    if mask.dtype == torch.bool:
+        mask = mask.to(out_dtype)
+    else:
+        mask = mask.to(out_dtype)
+        mask_region = mask > 0.5
+        mask[mask_region] = 1.0
+        mask[~mask_region] = 0.0
+
+    return mask
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py
index 9ed3c9bc507..138de398882 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/backend/ip_adapter/test_ip_adapter.py
@@ -1,8 +1,8 @@
 import pytest
 import torch
 
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model
 
 
@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
 
     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
 
diff --git a/tests/backend/util/test_mask.py b/tests/backend/util/test_mask.py
new file mode 100644
index 00000000000..96d3aab07f9
--- /dev/null
+++ b/tests/backend/util/test_mask.py
@@ -0,0 +1,88 @@
+import pytest
+import torch
+
+from invokeai.backend.util.mask import to_standard_float_mask
+
+
+def test_to_standard_float_mask_wrong_ndim():
+    with pytest.raises(ValueError):
+        to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32)
+
+
+def test_to_standard_float_mask_wrong_shape():
+    with pytest.raises(ValueError):
+        to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32)
+
+
+def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor):
+    """Helper function to check the result of `to_standard_float_mask()`."""
+    assert mask.shape == expected_mask.shape
+    assert mask.dtype == expected_mask.dtype
+    assert torch.allclose(mask, expected_mask)
+
+
+def test_to_standard_float_mask_ndim_2():
+    """Test the case where the input mask has shape (h, w)."""
+    mask = torch.zeros((3, 2), dtype=torch.float32)
+    mask[0, 0] = 1.0
+    mask[1, 1] = 1.0
+
+    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
+    expected_mask[0, 0, 0] = 1.0
+    expected_mask[0, 1, 1] = 1.0
+
+    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
+
+    check_mask_result(mask=new_mask, expected_mask=expected_mask)
+
+
+def test_to_standard_float_mask_ndim_3():
+    """Test the case where the input mask has shape (1, h, w)."""
+    mask = torch.zeros((1, 3, 2), dtype=torch.float32)
+    mask[0, 0, 0] = 1.0
+    mask[0, 1, 1] = 1.0
+
+    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
+    expected_mask[0, 0, 0] = 1.0
+    expected_mask[0, 1, 1] = 1.0
+
+    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
+
+    check_mask_result(mask=new_mask, expected_mask=expected_mask)
+
+
+@pytest.mark.parametrize(
+    "out_dtype",
+    [torch.float32, torch.float16],
+)
+def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype):
+    """Test the case where the input mask has dtype bool."""
+    mask = torch.zeros((3, 2), dtype=torch.bool)
+    mask[0, 0] = True
+    mask[1, 1] = True
+
+    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
+    expected_mask[0, 0, 0] = 1.0
+    expected_mask[0, 1, 1] = 1.0
+
+    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
+
+    check_mask_result(mask=new_mask, expected_mask=expected_mask)
+
+
+@pytest.mark.parametrize(
+    "out_dtype",
+    [torch.float32, torch.float16],
+)
+def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype):
+    """Test the case where the input mask has type float (but not all values are 0.0 or 1.0)."""
+    mask = torch.zeros((3, 2), dtype=torch.float32)
+    mask[0, 0] = 0.1  # Should be converted to 0.0
+    mask[0, 1] = 0.9  # Should be converted to 1.0
+
+    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
+    expected_mask[0, 0, 1] = 1.0
+
+    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
+
+    check_mask_result(mask=new_mask, expected_mask=expected_mask)
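
A quick way to exercise the new mask utility by hand, in addition to the tests above. This snippet is not part of the patch; it assumes an environment where the invokeai package (with this change applied) is importable:

import torch

from invokeai.backend.util.mask import to_standard_float_mask

# A (h, w) boolean mask: True marks the region that a conditioning prompt should apply to.
bool_mask = torch.zeros((64, 64), dtype=torch.bool)
bool_mask[:32, :32] = True

# Standardize to a (1, h, w) float mask whose values are exactly 0.0 or 1.0.
float_mask = to_standard_float_mask(bool_mask, out_dtype=torch.float32)
assert float_mask.shape == (1, 64, 64)
assert float_mask.dtype == torch.float32
assert set(float_mask.unique().tolist()) == {0.0, 1.0}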
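The dummy-mask branch in `_apply_standard_conditioning` also hints at how per-prompt masks and token ranges are expected to line up when several conditioning inputs are combined. The sketch below uses only the constructor arguments visible in the hunks above (`Range`, `TextConditioningRegions`); the tensor shapes, the two-prompt layout, and the (1, num_prompts, h, w) mask packing are assumptions made for illustration, not something this patch pins down:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions

h, w = 96, 64  # assumed latent-space height/width
embeds_a = torch.randn(1, 77, 768)  # embeddings for prompt A (illustrative shapes)
embeds_b = torch.randn(1, 77, 768)  # embeddings for prompt B
cond_embeds = torch.cat([embeds_a, embeds_b], dim=1)  # concatenated along the token axis

# One float mask per prompt; 1.0 marks the region that the prompt applies to.
mask_a = torch.zeros((1, 1, h, w))
mask_a[..., :, : w // 2] = 1.0  # prompt A covers the left half
mask_b = torch.zeros((1, 1, h, w))
mask_b[..., :, w // 2 :] = 1.0  # prompt B covers the right half

# Each Range records which slice of the concatenated token axis belongs to which mask.
regions = TextConditioningRegions(
    masks=torch.cat([mask_a, mask_b], dim=1),  # assumed packing: (1, num_prompts, h, w)
    ranges=[
        Range(start=0, end=embeds_a.shape[1]),
        Range(start=embeds_a.shape[1], end=cond_embeds.shape[1]),
    ],
)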