Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use xformers for single frame model #47

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions diffusers_helper/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Optional

import torch
from xformers.ops import memory_efficient_attention


class AttnProcessor2_0_xformers:
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim)

key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = memory_efficient_attention(
query, key, value, attention_mask, p=0.0
)

hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
6 changes: 3 additions & 3 deletions gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from diffusers_helper.code_cond import unet_add_coded_conds
from diffusers_helper.cat_cond import unet_add_concat_conds
from diffusers_helper.k_diffusion import KDiffusionSampler
from diffusers_helper.attention import AttnProcessor2_0_xformers
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
Expand All @@ -41,8 +41,8 @@ def from_config(cls, *args, **kwargs):
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16) # bfloat16 vae
unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)

unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())
unet.set_attn_processor(AttnProcessor2_0_xformers())
vae.set_attn_processor(AttnProcessor2_0_xformers())

video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
'lllyasviel/paints_undo_multi_frame',
Expand Down