diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py deleted file mode 100644 index 195cb12d1b8..00000000000 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ /dev/null @@ -1,182 +0,0 @@ -# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) -# and modified as needed - -# tencent-ailab comment: -# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0 - -from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights - - -# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict -# loading. -class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module): - def __init__(self): - DiffusersAttnProcessor2_0.__init__(self) - nn.Module.__init__(self) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the - ip_adapter_image_prompt_embeds parameter. - """ - return DiffusersAttnProcessor2_0.__call__( - self, attn, hidden_states, encoder_hidden_states, attention_mask, temb - ) - - -class IPAttnProcessor2_0(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - """ - - def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - assert len(weights) == len(scales) - - self._weights = weights - self._scales = scales - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Apply IP-Adapter attention. - - Args: - ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings. - Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len). - """ - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case, - # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here. - assert ip_adapter_image_prompt_embeds is not None - assert len(ip_adapter_image_prompt_embeds) == len(self._weights) - - for ipa_embed, ipa_weights, scale in zip( - ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True - ): - # The batch dimensions should match. - assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. - assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] - - ip_hidden_states = ipa_embed - - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - - hidden_states = hidden_states + scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 882becc84bb..3c1787bf16d 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,12 +23,12 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( IPAdapterConditioningInfo, TextConditioningData, ) from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from ..util import auto_detect_slice_size, normalize_device @@ -428,7 +428,7 @@ def generate_latents_from_embeddings( elif ip_adapter_data is not None: # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? # As it is now, the IP-Adapter will silently be skipped. - ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) + ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) self.use_ip_adapter = True else: @@ -492,7 +492,7 @@ def step( control_data: List[ControlNetData] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - ip_adapter_unet_patcher: Optional[UNetPatcher] = None, + ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py similarity index 53% rename from invokeai/backend/ip_adapter/unet_patcher.py rename to invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index f8c1870f6ee..364ec18da42 100644 --- a/invokeai/backend/ip_adapter/unet_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -1,52 +1,54 @@ from contextlib import contextmanager +from typing import Optional from diffusers.models import UNet2DConditionModel -from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 -class UNetPatcher: - """A class that contains multiple IP-Adapters and can apply them to a UNet.""" +class UNetAttentionPatcher: + """A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" - def __init__(self, ip_adapters: list[IPAdapter]): + def __init__(self, ip_adapters: Optional[list[IPAdapter]]): self._ip_adapters = ip_adapters - self._scales = [1.0] * len(self._ip_adapters) + self._ip_adapter_scales = None + + if self._ip_adapters is not None: + self._ip_adapter_scales = [1.0] * len(self._ip_adapters) def set_scale(self, idx: int, value: float): - self._scales[idx] = value + self._ip_adapter_scales[idx] = value def _prepare_attention_processors(self, unet: UNet2DConditionModel): """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention - weights into them. - + weights into them (if IP-Adapters are being applied). Note that the `unet` param is only used to determine attention block dimensions and naming. """ # Construct a dict of attention processors based on the UNet's architecture. attn_procs = {} for idx, name in enumerate(unet.attn_processors.keys()): - if name.endswith("attn1.processor"): - attn_procs[name] = AttnProcessor2_0() + if name.endswith("attn1.processor") or self._ip_adapters is None: + # "attn1" processors do not use IP-Adapters. + attn_procs[name] = CustomAttnProcessor2_0() else: # Collect the weights from each IP Adapter for the idx'th attention processor. - attn_procs[name] = IPAttnProcessor2_0( + attn_procs[name] = CustomAttnProcessor2_0( [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], - self._scales, + self._ip_adapter_scales, ) return attn_procs @contextmanager def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): - """A context manager that patches `unet` with IP-Adapter attention processors.""" - + """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers.""" attn_procs = self._prepare_attention_processors(unet) - orig_attn_processors = unet.attn_processors try: - # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the - # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy - # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. + # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from + # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a + # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. unet.set_attn_processor(attn_procs) yield None finally: diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 9ed3c9bc507..138de398882 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -1,8 +1,8 @@ import pytest import torch -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.test_utils import install_and_load_model @@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device): ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device) cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]} - ip_adapter_unet_patcher = UNetPatcher([ip_adapter]) + ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter]) with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet): output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample