Update the diffusion logic to use the new regional prompting feature.
RyanJDick committed Mar 10, 2024
1 parent 7fb5e46 commit 5be7ca4
Showing 3 changed files with 96 additions and 46 deletions.
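The regional prompting data consumed by these changes pairs a boolean mask over the latent with a token Range into the prompt embedding, one pair per region. For orientation, a minimal sketch of a two-region setup (not taken from this commit; stacking the masks along dim 1 and the example token spans are assumptions that mirror the dummy region built in _apply_standard_conditioning further down):

import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions

h, w = 96, 96  # latent spatial size
left = torch.zeros((1, 1, h, w), dtype=torch.bool)
left[:, :, :, : w // 2] = True   # left half of the image
right = ~left                    # right half of the image
regions = TextConditioningRegions(
    masks=torch.cat([left, right], dim=1),                      # assumed layout: (1, num_regions, h, w)
    ranges=[Range(start=0, end=38), Range(start=38, end=77)],   # hypothetical token spans, one per region
)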
43 changes: 28 additions & 15 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -417,22 +417,35 @@ def generate_latents_from_embeddings(
         if timesteps.shape[0] == 0:
             return latents

-        ip_adapter_unet_patcher = None
         extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = (
+            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+        )
+        if use_cross_attention_control and use_ip_adapter:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
+            )
+        if use_cross_attention_control and use_regional_prompting:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
+            )
+
+        unet_attention_patcher = None
+        self.use_ip_adapter = use_ip_adapter
+        attn_ctx = nullcontext()
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
             )
-            self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
-            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
-            # As it is now, the IP-Adapter will silently be skipped.
-            ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
-            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-            self.use_ip_adapter = True
-        else:
-            attn_ctx = nullcontext()
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
             if callback is not None:
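Net effect of the hunk above: prompt-to-prompt .swap() control is now mutually exclusive with both IP-Adapter and regional prompting, and the UNet attention patcher is installed whenever either of the latter two is active. A minimal sketch of the regional-prompting-only case, assuming (as the new code implies) that UNetAttentionPatcher accepts None when no IP-Adapters are present; unet stands in for self.invokeai_diffuser.model and the import is omitted:

patcher = UNetAttentionPatcher(None)             # no IP-Adapters, regional prompting only
with patcher.apply_ip_adapter_attention(unet):   # patches the UNet's attention processors
    ...                                          # denoising loop; region masks arrive via cross_attention_kwargs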
@@ -460,7 +473,7 @@ def generate_latents_from_embeddings(
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
+                    unet_attention_patcher=unet_attention_patcher,
                 )
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, "pred_original_sample", None)
@@ -492,7 +505,7 @@ def step(
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None,
+        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -515,10 +528,10 @@
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    ip_adapter_unet_patcher.set_scale(i, weight)
+                    unet_attention_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    ip_adapter_unet_patcher.set_scale(i, 0.0)
+                    unet_attention_patcher.set_scale(i, 0.0)

         # Handle ControlNet(s)
         down_block_additional_residuals = None
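The loop above only adjusts each IP-Adapter's scale per step; the conversion of begin/end percentages into step indices happens just before the lines shown. A rough, self-contained restatement of the gating (the floor/ceil rounding is an assumption, since that code is outside this hunk):

import math

def ip_adapter_scale(step_index: int, total_step_count: int, begin: float, end: float, weight: float) -> float:
    first_adapter_step = math.floor(begin * total_step_count)
    last_adapter_step = math.ceil(end * total_step_count)
    # Apply the IP-Adapter only while the current step is inside its begin/end range.
    return weight if first_adapter_step <= step_index <= last_adapter_step else 0.0

The remaining hunks in this commit update the custom attention processor and the shared diffusion conditioning logic.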
@@ -88,13 +88,12 @@ def __call__(
         # End unmodified block from AttnProcessor2_0.

         # Handle regional prompt attention masks.
-        if regional_prompt_data is not None:
+        if regional_prompt_data is not None and is_cross_attention:
             assert percent_through is not None
             _, query_seq_len, _ = hidden_states.shape
-            if is_cross_attention:
-                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                    query_seq_len=query_seq_len, key_seq_len=sequence_length
-                )
+            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                query_seq_len=query_seq_len, key_seq_len=sequence_length
+            )

             if attention_mask is None:
                 attention_mask = prompt_region_attention_mask
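With the reordered condition, the region mask is only built for cross-attention layers, where the queries are latent positions and the keys are prompt tokens. The intended semantics (inferred from the masks/ranges pairing in TextConditioningRegions, not stated in this hunk): attention from latent position q to token k is kept when q falls inside a region whose token Range contains k, so each image region reads only from its own prompt span.

mask = regional_prompt_data.get_cross_attn_mask(
    query_seq_len=query_seq_len,   # number of latent positions at this attention resolution
    key_seq_len=sequence_length,   # number of text tokens in the prompt embedding
)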
@@ -12,8 +12,11 @@
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ExtraConditioningInfo,
     IPAdapterConditioningInfo,
+    Range,
     TextConditioningData,
+    TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

 from .cross_attention_control import (
     CrossAttentionType,
@@ -206,9 +209,9 @@ def do_unet_step(
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
@@ -225,6 +228,7 @@
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -239,6 +243,7 @@
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -301,6 +306,7 @@ def _apply_standard_conditioning(
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -311,17 +317,13 @@
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
         if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
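Replacing the single dict literal with an empty dict that is filled in lets independent features coexist in cross_attention_kwargs rather than overwrite one another. In this batched path the dict can now carry, at most, entries like the following (image_prompt_embeds is a placeholder name; the keys match the hunks around this point):

cross_attention_kwargs = {}
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = image_prompt_embeds  # when IP-Adapter is active
cross_attention_kwargs["regional_prompt_data"] = regional_prompt_data           # when region masks are present
cross_attention_kwargs["percent_through"] = percent_through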
@@ -343,6 +345,31 @@
                 ),
             }

+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
             conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
         )
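RegionalPromptData is constructed from an [uncond, cond] pair here, so a conditioning half without masks still needs an entry; the dummy region marks the whole latent as a single region spanning the prompt's full token range. A worked instance under assumed sizes (77 prompt tokens, a 96x96 latent):

import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions

_, _, h, w = (2, 4, 96, 96)  # assumed x.shape for a batch of SD latents
dummy = TextConditioningRegions(
    masks=torch.ones((1, 1, h, w), dtype=torch.bool),  # one all-True region covering the latent
    ranges=[Range(start=0, end=77)],                    # spanning every token of that prompt
)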
@@ -366,6 +393,7 @@ def _apply_standard_conditioning_sequentially(
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
@@ -413,21 +441,19 @@ def _apply_standard_conditioning_sequentially(
         # Unconditioned pass
         #####################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         # Prepare cross-attention control kwargs for the unconditioned pass.
         if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
@@ -437,6 +463,13 @@
"time_ids": conditioning_data.uncond_text.add_time_ids,
}

# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through

# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
@@ -453,22 +486,20 @@
         # Conditioned pass
         ###################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         # Prepare cross-attention control kwargs for the conditioned pass.
         if cross_attn_processor_context is not None:
             cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
@@ -478,6 +509,13 @@
"time_ids": conditioning_data.cond_text.add_time_ids,
}

# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through

# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
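Unlike the batched path, each sequential pass builds its own RegionalPromptData from a single region set (uncond earlier, cond here). percent_through is simply the fraction of the schedule completed, hoisted in do_unet_step above; restated:

percent_through = step_index / total_step_count   # 0.0 at the first step, approaching 1.0 at the last
cross_attention_kwargs["percent_through"] = percent_through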
