
Implement CustomDiffusionAttnProcessor2_0. #4604

Merged 9 commits on Sep 18, 2023
5 changes: 4 additions & 1 deletion docs/source/en/api/attnprocessor.md
@@ -17,6 +17,9 @@ An attention processor is a class for applying different types of attention mechanisms.
## CustomDiffusionAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor

## CustomDiffusionAttnProcessor2_0
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0

## AttnAddedKVProcessor
[[autodoc]] models.attention_processor.AttnAddedKVProcessor

@@ -39,4 +42,4 @@ An attention processor is a class for applying different types of attention mechanisms.
[[autodoc]] models.attention_processor.SlicedAttnProcessor

## SlicedAttnAddedKVProcessor
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
10 changes: 8 additions & 2 deletions examples/custom_diffusion/train_custom_diffusion.py
@@ -51,7 +51,11 @@
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
from diffusers.models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -870,7 +874,9 @@ def main(args):
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)

attention_class = CustomDiffusionAttnProcessor
attention_class = (
CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
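For context, the selected `attention_class` is instantiated once per attention layer and installed on the UNet further down in the training script. The sketch below shows that wiring in isolation; it assumes `CompVis/stable-diffusion-v1-4` as an example checkpoint, trains only the cross-attention K/V projections (`train_q_out=False`), and omits the xformers branch and modifier-token handling.

```python
import torch.nn.functional as F

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

# Prefer the PyTorch 2.0 processor when scaled_dot_product_attention is available.
attention_class = (
    CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
)

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

custom_diffusion_attn_procs = {}
for name in unet.attn_processors.keys():
    # "attn1" layers are self-attention (no text conditioning); "attn2" layers are cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is None:
        # Self-attention stays frozen: no new projections are created for this layer.
        custom_diffusion_attn_procs[name] = attention_class(train_kv=False, train_q_out=False)
    else:
        # Cross-attention gets trainable K/V projections over the text embeddings.
        custom_diffusion_attn_procs[name] = attention_class(
            train_kv=True,
            train_q_out=False,
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
        )

unet.set_attn_processor(custom_diffusion_attn_procs)
```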
15 changes: 13 additions & 2 deletions src/diffusers/loaders.py
@@ -559,6 +559,7 @@ def save_attn_procs(
"""
from .models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
)

@@ -578,15 +579,25 @@ def save_function(weights, filename):
os.makedirs(save_directory, exist_ok=True)

is_custom_diffusion = any(
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
isinstance(
x,
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
)
for (_, x) in self.attn_processors.items()
)
if is_custom_diffusion:
model_to_save = AttnProcsLayers(
{
y: x
for (y, x) in self.attn_processors.items()
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
if isinstance(
x,
(
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
),
)
}
)
state_dict = model_to_save.state_dict()
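Downstream of this filter, `save_attn_procs` writes out only the layers whose processor is one of the custom-diffusion classes, and the weights can later be loaded back into a pipeline. A hedged sketch of that round trip: the directory name is made up, `pytorch_custom_diffusion_weights.bin` is the weight-file name the Custom Diffusion example uses, and `unet` is assumed to already carry trained custom-diffusion processors.

```python
import torch

from diffusers import DiffusionPipeline

# Save: only the custom-diffusion attention layers are serialized.
unet.save_attn_procs(
    "custom-diffusion-cat",
    safe_serialization=False,
    weight_name="pytorch_custom_diffusion_weights.bin",
)

# Load the trained attention weights back into a pipeline's UNet.
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs("custom-diffusion-cat", weight_name="pytorch_custom_diffusion_weights.bin")
```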
116 changes: 114 additions & 2 deletions src/diffusers/models/attention_processor.py
@@ -173,7 +173,8 @@ def set_use_memory_efficient_attention_xformers(
LORA_ATTENTION_PROCESSORS,
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
self.processor,
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
)
is_added_kv_processor = hasattr(self, "processor") and isinstance(
self.processor,
@@ -261,7 +262,12 @@ def set_use_memory_efficient_attention_xformers(
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
processor = CustomDiffusionAttnProcessor(
attn_processor_class = (
CustomDiffusionAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else CustomDiffusionAttnProcessor
)
processor = attn_processor_class(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
@@ -1156,6 +1162,111 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
return hidden_states


class CustomDiffusionAttnProcessor2_0(nn.Module):
r"""
Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
dot-product attention.

Args:
train_kv (`bool`, defaults to `True`):
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`):
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`):
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
"""

def __init__(
self,
train_kv=True,
train_q_out=True,
hidden_size=None,
cross_attention_dim=None,
out_bias=True,
dropout=0.0,
):
super().__init__()
self.train_kv = train_kv
self.train_q_out = train_q_out

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim

# `_custom_diffusion` id for easy serialization and loading.
if self.train_kv:
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
if self.train_q_out:
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
self.to_out_custom_diffusion = nn.ModuleList([])
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states)
else:
query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
crossattn = False
encoder_hidden_states = hidden_states
else:
crossattn = True
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states)
value = self.to_v_custom_diffusion(encoder_hidden_states)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

if crossattn:
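# in cross-attention, detach the key/value of the first text token (typically the start-of-text
# embedding) so it receives no gradient; only the remaining token positions are trained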
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()

inner_dim = hidden_states.shape[-1]

head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if self.train_q_out:
# linear proj
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

return hidden_states


class SlicedAttnProcessor:
r"""
Processor for implementing sliced attention.
@@ -1639,6 +1750,7 @@ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
XFormersAttnAddedKVProcessor,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
# deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
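Because the new processor is intended as a drop-in, numerically equivalent replacement for `CustomDiffusionAttnProcessor`, a quick sanity check is to load both processors with identical weights onto a single `Attention` block and compare outputs. A minimal sketch with assumed SD-1.x-like dimensions; the tolerance is illustrative and may need loosening in half precision.

```python
import torch

from diffusers.models.attention_processor import (
    Attention,
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

torch.manual_seed(0)

# One cross-attention block with assumed SD-1.x-like dimensions (inner dim 320, text dim 768).
attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)

proc_ref = CustomDiffusionAttnProcessor(train_kv=True, train_q_out=True, hidden_size=320, cross_attention_dim=768)
proc_sdpa = CustomDiffusionAttnProcessor2_0(train_kv=True, train_q_out=True, hidden_size=320, cross_attention_dim=768)
proc_sdpa.load_state_dict(proc_ref.state_dict())  # same custom-diffusion weights in both processors

hidden_states = torch.randn(2, 64, 320)
encoder_hidden_states = torch.randn(2, 77, 768)

with torch.no_grad():
    attn.set_processor(proc_ref)
    out_ref = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

    attn.set_processor(proc_sdpa)
    out_sdpa = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

# The two implementations should agree up to floating-point differences between
# the baddbmm/softmax path and F.scaled_dot_product_attention.
assert torch.allclose(out_ref, out_sdpa, atol=1e-4)
```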