Create a UNetAttentionPatcher for patching UNet models with CustomAttnProcessor2_0 modules.
RyanJDick committed Mar 8, 2024
1 parent d969be8 commit 7fb5e46
Showing 4 changed files with 25 additions and 205 deletions.
182 changes: 0 additions & 182 deletions invokeai/backend/ip_adapter/attention_processor.py

This file was deleted.

6 changes: 3 additions & 3 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -23,12 +23,12 @@
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     IPAdapterConditioningInfo,
     TextConditioningData,
 )
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 
 from ..util import auto_detect_slice_size, normalize_device

@@ -428,7 +428,7 @@ def generate_latents_from_embeddings(
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
+            ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
             attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
             self.use_ip_adapter = True
         else:
@@ -492,7 +492,7 @@ def step(
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
+        ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
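The call sites above keep the same context-manager usage; only the class name changes. A minimal sketch of the pattern (assuming an InvokeAI environment; `unet`, `ip_adapter`, `latents`, `timestep`, `text_embeds`, and `ip_embeds` are placeholder variables, not part of this diff):

from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

# With a list of IP-Adapters, the patcher loads their attention weights; with
# None, it still installs CustomAttnProcessor2_0 on every attention layer.
patcher = UNetAttentionPatcher([ip_adapter])
patcher.set_scale(0, 0.75)  # conditioning strength of the 0th IP-Adapter

with patcher.apply_ip_adapter_attention(unet):
    # Inside the block the UNet runs with the patched processors; the image
    # prompt embeds reach them via cross_attention_kwargs (as in the test below).
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": [ip_embeds]},
    ).sample
# On exit, the UNet's original attention processors are restored.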
38 changes: 20 additions & 18 deletions invokeai/backend/{ip_adapter/unet_patcher.py → stable_diffusion/diffusion/unet_attention_patcher.py}
@@ -1,52 +1,54 @@
 from contextlib import contextmanager
+from typing import Optional
 
 from diffusers.models import UNet2DConditionModel
 
-from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 
 
-class UNetPatcher:
-    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+class UNetAttentionPatcher:
+    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
 
-    def __init__(self, ip_adapters: list[IPAdapter]):
+    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
         self._ip_adapters = ip_adapters
-        self._scales = [1.0] * len(self._ip_adapters)
+        self._ip_adapter_scales = None
+
+        if self._ip_adapters is not None:
+            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
 
     def set_scale(self, idx: int, value: float):
-        self._scales[idx] = value
+        self._ip_adapter_scales[idx] = value
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-        weights into them.
+        weights into them (if IP-Adapters are being applied).
         Note that the `unet` param is only used to determine attention block dimensions and naming.
         """
         # Construct a dict of attention processors based on the UNet's architecture.
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
-            if name.endswith("attn1.processor"):
-                attn_procs[name] = AttnProcessor2_0()
+            if name.endswith("attn1.processor") or self._ip_adapters is None:
+                # "attn1" processors do not use IP-Adapters.
+                attn_procs[name] = CustomAttnProcessor2_0()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = IPAttnProcessor2_0(
+                attn_procs[name] = CustomAttnProcessor2_0(
                     [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                    self._scales,
+                    self._ip_adapter_scales,
                 )
         return attn_procs
 
     @contextmanager
     def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
-        """A context manager that patches `unet` with IP-Adapter attention processors."""
-
+        """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
         attn_procs = self._prepare_attention_processors(unet)
 
         orig_attn_processors = unet.attn_processors
 
         try:
-            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
+            # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
+            # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
             unet.set_attn_processor(attn_procs)
             yield None
         finally:
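The try/finally swap above is a general diffusers pattern that can be shown standalone. A runnable sketch using stock diffusers (the helper `swap_attn_processors` and the tiny UNet config are illustrative, not InvokeAI code); it also demonstrates why the comment above warns that `set_attn_processor(...)` pops entries from the dict it receives:

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0


@contextmanager
def swap_attn_processors(unet: UNet2DConditionModel, attn_procs: dict):
    # unet.attn_processors is a property that builds a fresh dict on each
    # access, so the snapshot below survives later set_attn_processor() calls.
    orig_attn_processors = unet.attn_processors
    try:
        # Pass a copy: set_attn_processor() pops entries out of the dict it is
        # given, which would otherwise empty attn_procs for the caller.
        unet.set_attn_processor({k: v for k, v in attn_procs.items()})
        yield
    finally:
        unet.set_attn_processor(orig_attn_processors)


# A deliberately tiny UNet so the sketch runs quickly on CPU.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=1,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
new_procs = {name: AttnProcessor2_0() for name in unet.attn_processors}
with swap_attn_processors(unet, new_procs):
    out = unet(
        torch.randn(1, 4, 32, 32),
        timestep=0,
        encoder_hidden_states=torch.randn(1, 8, 32),
    ).sample
assert len(new_procs) > 0  # still intact, because we passed a copy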
4 changes: 2 additions & 2 deletions tests/backend/ip_adapter/test_ip_adapter.py
@@ -1,8 +1,8 @@
 import pytest
 import torch
 
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model


@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
 
     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
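One detail the test relies on without comment: diffusers forwards extra entries in cross_attention_kwargs to each attention processor's __call__, which is how ip_adapter_image_prompt_embeds reaches the patched processors. A self-contained sketch of that plumbing (EmbedsAwareProcessor and the tiny config are illustrative, not InvokeAI code):

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0


class EmbedsAwareProcessor(AttnProcessor2_0):
    """Standard SDPA attention that also accepts the custom kwarg the test passes."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, ip_adapter_image_prompt_embeds=None, **kwargs):
        if ip_adapter_image_prompt_embeds is not None:
            # A real IP-Adapter processor would run extra cross-attention against
            # these embeds here; we only show that they arrive.
            print("got image prompt embeds:", [e.shape for e in ip_adapter_image_prompt_embeds])
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)


unet = UNet2DConditionModel(  # same tiny config as the sketch above
    block_out_channels=(32, 64), layers_per_block=1, sample_size=32,
    in_channels=4, out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet.set_attn_processor({name: EmbedsAwareProcessor() for name in unet.attn_processors})

output = unet(
    torch.randn(1, 4, 32, 32),
    timestep=0,
    encoder_hidden_states=torch.randn(1, 8, 32),
    cross_attention_kwargs={"ip_adapter_image_prompt_embeds": [torch.randn(1, 3, 4, 32)]},
).sample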
