diff --git a/diffusers_helper/attention.py b/diffusers_helper/attention.py
new file mode 100644
index 0000000..f1efd91
--- /dev/null
+++ b/diffusers_helper/attention.py
@@ -0,0 +1,81 @@
+from typing import Optional
+
+import torch
+from xformers.ops import memory_efficient_attention
+
+
+class AttnProcessor2_0_xformers:
+    def __call__(
+        self,
+        attn,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # memory_efficient_attention expects the attn_bias shape to be
+            # (batch, heads, query_length, key_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # q/k/v and the output are laid out as (batch, seq_len, num_heads, head_dim), so no transpose is needed
+        # NOTE: attn.scale is not forwarded; memory_efficient_attention defaults to 1/sqrt(head_dim)
+        hidden_states = memory_efficient_attention(
+            query, key, value, attention_mask, p=0.0
+        )
+
+        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
\ No newline at end of file
diff --git a/gradio_app.py b/gradio_app.py
index 831396b..d7fdbff 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -19,8 +19,8 @@
 from diffusers_helper.code_cond import unet_add_coded_conds
 from diffusers_helper.cat_cond import unet_add_concat_conds
 from diffusers_helper.k_diffusion import KDiffusionSampler
+from diffusers_helper.attention import AttnProcessor2_0_xformers
 from diffusers import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
@@ -41,8 +41,8 @@ def from_config(cls, *args, **kwargs):
 vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)  # bfloat16 vae
 unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
 
-unet.set_attn_processor(AttnProcessor2_0())
-vae.set_attn_processor(AttnProcessor2_0())
+unet.set_attn_processor(AttnProcessor2_0_xformers())
+vae.set_attn_processor(AttnProcessor2_0_xformers())
 
 video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
     'lllyasviel/paints_undo_multi_frame',
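A quick way to sanity-check the new processor against the stock SDPA one, outside the full pipeline (a minimal sketch, not part of this patch; it assumes a CUDA GPU with xformers installed and uses a toy diffusers Attention block rather than the real UNet):

import torch
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers_helper.attention import AttnProcessor2_0_xformers

# toy self-attention block: 8 heads of 64 dims, fp16 on GPU (xformers kernels need CUDA)
attn = Attention(query_dim=512, heads=8, dim_head=64).to("cuda", torch.float16)
x = torch.randn(2, 77, 512, device="cuda", dtype=torch.float16)

attn.set_processor(AttnProcessor2_0())
with torch.no_grad():
    ref = attn(x)

attn.set_processor(AttnProcessor2_0_xformers())
with torch.no_grad():
    out = attn(x)

# fp16 kernels differ slightly between backends, so expect a small but nonzero difference
print((ref - out).abs().max())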