From 47e6ca846b567e7f0f774a5db45c81d9b90cf441 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 15 Mar 2023 20:02:53 -0500 Subject: [PATCH 1/2] build a "linearized" unet using https://github.com/jansel/pytorch-jit-paritybench --- README.md | 2 +- examples/unet.py | 74 +- pi/models/unet.py | 11078 ++++++++++++++++++++++++++ pi/models/unet/__init__.py | 1 - pi/models/unet/attention.py | 519 -- pi/models/unet/cross_attention.py | 679 -- pi/models/unet/embeddings.py | 385 - pi/models/unet/resnet.py | 752 -- pi/models/unet/transformer_2d.py | 357 - pi/models/unet/unet_2d_blocks.py | 2312 ------ pi/models/unet/unet_2d_condition.py | 517 -- pyproject.toml | 6 + 12 files changed, 11138 insertions(+), 5544 deletions(-) create mode 100644 pi/models/unet.py delete mode 100644 pi/models/unet/__init__.py delete mode 100644 pi/models/unet/attention.py delete mode 100644 pi/models/unet/cross_attention.py delete mode 100644 pi/models/unet/embeddings.py delete mode 100644 pi/models/unet/resnet.py delete mode 100644 pi/models/unet/transformer_2d.py delete mode 100644 pi/models/unet/unet_2d_blocks.py delete mode 100644 pi/models/unet/unet_2d_condition.py diff --git a/README.md b/README.md index 7a101c1..3b55d81 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Early days of a lightweight MLIR Python frontend with support for PyTorch (throu Just ```shell -pip install - requirements.txt +pip install -r requirements.txt pip install . --no-build-isolation ``` diff --git a/examples/unet.py b/examples/unet.py index ca51b1d..268fef9 100644 --- a/examples/unet.py +++ b/examples/unet.py @@ -1,27 +1,59 @@ -import pi -from pi import nn -from pi.mlir.utils import pipile -from pi.utils.annotations import annotate_args +import torch +import numpy as np + from pi.models.unet import UNet2DConditionModel +import torch_mlir + +unet = UNet2DConditionModel( + **{ + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } +) +unet.eval() + +batch_size = 4 +num_channels = 4 +sizes = (32, 32) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(np.random.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + +noise = floats_tensor((batch_size, num_channels) + sizes) +time_step = torch.tensor([10]) +encoder_hidden_states = floats_tensor((batch_size, 4, 32)) -class MyUNet(nn.Module): - def __init__(self): - super().__init__() - self.unet = UNet2DConditionModel() +output = unet(noise, time_step, encoder_hidden_states) +print(output) - @annotate_args( - [ - None, - ([-1, -1, -1, -1], pi.float32, True), - ] - ) - def forward(self, x): - y = self.resnet(x) - return y +traced = torch.jit.trace(unet, (noise, time_step, encoder_hidden_states), strict=False) +frozen = torch.jit.freeze(traced) +print(frozen.graph) -test_module = MyUNet() -x = pi.randn((1, 3, 64, 64)) -mlir_module = pipile(test_module, example_args=(x,)) -print(mlir_module) +module = torch_mlir.compile( + frozen, + (noise, time_step, encoder_hidden_states), + use_tracing=True, + output_type=torch_mlir.OutputType.RAW, +) +with open("unet.mlir", "w") as f: + f.write(str(module)) diff --git a/pi/models/unet.py b/pi/models/unet.py new file mode 100644 index 
0000000..7491e73 --- /dev/null +++ b/pi/models/unet.py @@ -0,0 +1,11078 @@ +import importlib +import json +import logging +import sys +from pathlib import Path, PosixPath + +logger = logging.getLogger(__name__) + +import functools +import operator as op +import os +from dataclasses import dataclass, fields +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from torch import Tensor, device +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from uuid import uuid4 +import inspect +import math +import numpy as np +import re +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import warnings + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cut_power=1.0): + super().__init__() + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int( + torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size + ) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[ + :, :, offsety : offsety + size, offsetx : offsetx + size + ] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +class DiffusionUncond(nn.Module): + def __init__(self, global_args): + super().__init__() + self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) + self.diffusion_ema = deepcopy(self.diffusion) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + +class AttnProcsLayers(torch.nn.Module): + def __init__(self, state_dict: "Dict[str, torch.Tensor]"): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + return new_state_dict + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = key.split(".processor")[0] + ".processor" + new_key = key.replace( + replace_key, f"layers.{module.rev_mapping[replace_key]}" + ) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +def is_xformers_available(): + return _xformers_available + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. 
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + def __init__( + self, + channels: "int", + num_head_channels: "Optional[int]" = None, + norm_num_groups: "int" = 32, + rescale_output_factor: "float" = 1.0, + eps: "float" = 1e-05, + ): + super().__init__() + self.channels = channels + self.num_heads = ( + channels // num_head_channels if num_head_channels is not None else 1 + ) + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm( + num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True + ) + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + self._use_memory_efficient_attention_xformers = False + self._attention_op = None + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size * head_size, seq_len, dim // head_size + ) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: "bool", + attention_op: "Optional[Callable]" = None, + ): + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only available for GPU " + ) + else: + try: + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = ( + use_memory_efficient_attention_xformers + ) + self._attention_op = attention_op + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.view(batch, channel, height * width).transpose( + 1, 2 + ) + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + scale = 1 / math.sqrt(self.channels / self.num_heads) + query_proj = self.reshape_heads_to_batch_dim(query_proj) + key_proj = self.reshape_heads_to_batch_dim(key_proj) + value_proj = self.reshape_heads_to_batch_dim(value_proj) + if self._use_memory_efficient_attention_xformers: + hidden_states = xformers.ops.memory_efficient_attention( + query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op + ) + hidden_states = hidden_states + else: + attention_scores = torch.baddbmm( + torch.empty( + query_proj.shape[0], + query_proj.shape[1], + key_proj.shape[1], + dtype=query_proj.dtype, + device=query_proj.device, + ), + query_proj, + key_proj.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type( + attention_scores.dtype + ) + hidden_states = torch.bmm(attention_probs, value_proj) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch, channel, height, width + ) + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding( + num_classes + use_cfg_embedding, hidden_size + ) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
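+        Dropped labels are replaced with the index `num_classes`, which selects the extra row of the
+        embedding table reserved for the unconditional (null) class.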
+ """ + if force_drop_ids is None: + drop_ids = ( + torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + ) + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if self.training and use_dropout or force_drop_ids is not None: + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: "int", + time_embed_dim: "int", + act_fn: "str" = "silu", + out_dim: "int" = None, + post_act_fn: "Optional[str]" = None, + cond_proj_dim=None, + ): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "gelu": + self.act = nn.GELU() + else: + raise ValueError( + f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" + ) + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + if post_act_fn is None: + self.post_act = None + elif post_act_fn == "silu": + self.post_act = nn.SiLU() + elif post_act_fn == "mish": + self.post_act = nn.Mish() + elif post_act_fn == "gelu": + self.post_act = nn.GELU() + else: + raise ValueError( + f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" + ) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + if self.act is not None: + sample = self.act(sample) + sample = self.linear_2(sample) + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +def get_timestep_embedding( + timesteps: "torch.Tensor", + embedding_dim: "int", + flip_sin_to_cos: "bool" = False, + downscale_freq_shift: "float" = 1, + scale: "float" = 1, + max_period: "int" = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + emb = scale * emb + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + def __init__( + self, + num_channels: "int", + flip_sin_to_cos: "bool", + downscale_freq_shift: "float", + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1 + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim + ) + self.class_embedder = LabelEmbedding( + num_classes, embedding_dim, class_dropout_prob + ) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj) + class_labels = self.class_embedder(class_labels) + conditioning = timesteps_emb + class_labels + return conditioning + + +class AdaLayerNormZero(nn.Module): + """ + Norm layer adaptive layer norm zero (adaLN-Zero). + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-06) + + def forward(self, x, timestep, class_labels, hidden_dtype=None): + emb = self.linear( + self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)) + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( + 6, dim=1 + ) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class CrossAttnAddedKVProcessor: + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + residual = hidden_states + hidden_states = hidden_states.view( + hidden_states.shape[0], hidden_states.shape[1], -1 + ).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = attn.head_to_batch_dim( + encoder_hidden_states_value_proj + ) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states + + +class CrossAttnProcessor: + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, 
sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4): + super().__init__() + if rank > min(in_features, out_features): + raise ValueError( + f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" + ) + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + down_hidden_states = self.down(hidden_states) + up_hidden_states = self.up(down_hidden_states) + return up_hidden_states + + +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_v_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( + encoder_hidden_states + ) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( + encoder_hidden_states + ) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( + hidden_states + ) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__( + self, + hidden_size, + cross_attention_dim, + rank=4, + attention_op: "Optional[Callable]" = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.attention_op = attention_op + 
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_v_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( + encoder_hidden_states + ) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( + encoder_hidden_states + ) + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( + hidden_states + ) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class SlicedAttnAddedKVProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "'CrossAttention'", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + residual = hidden_states + hidden_states = hidden_states.view( + hidden_states.shape[0], hidden_states.shape[1], -1 + ).transpose(1, 2) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.head_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = attn.head_to_batch_dim( + encoder_hidden_states_value_proj + ) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] + if attention_mask is not None + else None + ) + attn_slice = 
attn.get_attention_scores( + query_slice, key_slice, attn_mask_slice + ) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states + + +class SlicedAttnProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] + if attention_mask is not None + else None + ) + attn_slice = attn.get_attention_scores( + query_slice, key_slice, attn_mask_slice + ) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class XFormersCrossAttnProcessor: + def __init__(self, attention_op: "Optional[Callable]" = None): + self.attention_op = attention_op + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return 
hidden_states + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse( + version_name + ): + raise ValueError( + f"The deprecation tuple {attribute, version_name, message} should be removed since diffusers' version {__version__} is >= {version_name}" + ) + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning, stacklevel=2) + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError( + f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`" + ) + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values + + +class CrossAttention(nn.Module): + """ + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
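+        processor (`AttnProcessor`, *optional*):
+            The attention processor to use. When omitted, `AttnProcessor2_0` is selected if
+            `F.scaled_dot_product_attention` is available and `scale_qk` is True, otherwise `CrossAttnProcessor`.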
+ """ + + def __init__( + self, + query_dim: "int", + cross_attention_dim: "Optional[int]" = None, + heads: "int" = 8, + dim_head: "int" = 64, + dropout: "float" = 0.0, + bias=False, + upcast_attention: "bool" = False, + upcast_softmax: "bool" = False, + cross_attention_norm: "bool" = False, + added_kv_proj_dim: "Optional[int]" = None, + norm_num_groups: "Optional[int]" = None, + out_bias: "bool" = True, + scale_qk: "bool" = True, + processor: "Optional['AttnProcessor']" = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.cross_attention_norm = cross_attention_norm + self.scale = dim_head ** -0.5 if scale_qk else 1.0 + self.heads = heads + self.sliceable_head_dim = heads + self.added_kv_proj_dim = added_kv_proj_dim + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=inner_dim, + num_groups=norm_num_groups, + eps=1e-05, + affine=True, + ) + else: + self.group_norm = None + if cross_attention_norm: + self.norm_cross = nn.LayerNorm(cross_attention_dim) + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + if processor is None: + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and scale_qk + else CrossAttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: "bool", + attention_op: "Optional[Callable]" = None, + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) + ) + if use_memory_efficient_attention_xformers: + if self.added_kv_proj_dim is not None: + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only available for GPU " + ) + else: + try: + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + if is_lora: + processor = LoRAXFormersCrossAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor + else: + processor = XFormersCrossAttnProcessor(attention_op=attention_op) + elif is_lora: + processor = LoRACrossAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor + else: + processor = CrossAttnProcessor() + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." + ) + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = CrossAttnAddedKVProcessor() + else: + processor = CrossAttnProcessor() + self.set_processor(processor) + + def set_processor(self, processor: "'AttnProcessor'"): + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + self.processor = processor + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + **cross_attention_kwargs, + ): + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size * head_size, seq_len, dim // head_size + ) + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + attention_scores = torch.baddbmm( + baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale + ) + if self.upcast_softmax: + attention_scores = attention_scores.float() + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs + return attention_probs + + def 
prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to `prepare_attention_mask` when preparing the attention_mask.", + ) + batch_size = 1 + head_size = self.heads + if attention_mask is None: + return attention_mask + if attention_mask.shape[-1] != target_length: + if attention_mask.device.type == "mps": + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + return attention_mask + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: "int", dim_out: "int"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class GEGLU(nn.Module): + """ + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: "int", dim_out: "int"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + return F.gelu(gate.to(dtype=torch.float32)) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class GELU(nn.Module): + """ + GELU activation function with tanh approximation support with `approximate="tanh"`. + """ + + def __init__(self, dim_in: "int", dim_out: "int", approximate: "str" = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + """ + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 
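+        The hidden width is `int(dim * mult)`; supported `activation_fn` values are `"gelu"`,
+        `"gelu-approximate"`, `"geglu"`, and `"geglu-approximate"`.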
+ """ + + def __init__( + self, + dim: "int", + dim_out: "Optional[int]" = None, + mult: "int" = 4, + dropout: "float" = 0.0, + activation_fn: "str" = "geglu", + final_dropout: "bool" = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + self.net = nn.ModuleList([]) + self.net.append(act_fn) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out)) + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + """ + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: "int", + num_attention_heads: "int", + attention_head_dim: "int", + dropout=0.0, + cross_attention_dim: "Optional[int]" = None, + activation_fn: "str" = "geglu", + num_embeds_ada_norm: "Optional[int]" = None, + attention_bias: "bool" = False, + only_cross_attention: "bool" = False, + upcast_attention: "bool" = False, + norm_elementwise_affine: "bool" = True, + norm_type: "str" = "layer_norm", + final_dropout: "bool" = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm_zero" + ) + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm" + ) + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 
+ ) + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + if cross_attention_dim is not None: + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + norm_hidden_states = self.norm3(hidden_states) + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + ff_output = self.ff(norm_hidden_states) + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = ff_output + hidden_states + return hidden_states + + +class AdaGroupNorm(nn.Module): + """ + GroupNorm layer modified to incorporate timestep embeddings. 
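+    The conditioning embedding is projected to `2 * out_dim` and split into per-channel scale and shift,
+    applied as `x * (1 + scale) + shift` after a parameter-free group norm.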
+ """ + + def __init__( + self, + embedding_dim: "int", + out_dim: "int", + num_groups: "int", + act_fn: "Optional[str]" = None, + eps: "float" = 1e-05, + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + self.act = None + if act_fn == "swish": + self.act = lambda x: F.silu(x) + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "gelu": + self.act = nn.GELU() + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x, emb): + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: "int", + conditioning_channels: "int" = 3, + block_out_channels: "Tuple[int]" = (16, 32, 96, 256), + ): + super().__init__() + self.conv_in = nn.Conv2d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 + ) + self.blocks = nn.ModuleList([]) + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1) + ) + self.blocks.append( + nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2) + ) + self.conv_out = zero_module( + nn.Conv2d( + block_out_channels[-1], + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + embedding = self.conv_out(embedding) + return embedding + + +STR_OPERATION_TO_FUNC = { + ">": op.gt, + ">=": op.ge, + "==": op.eq, + "!=": op.ne, + "<=": op.le, + "<": op.lt, +} + + +def compare_versions( + library_or_version: "Union[str, Version]", + operation: "str", + requirement_version: "str", +): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. 
+ requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError( + f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}" + ) + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +def is_transformers_version(operation: "str", version: "str"): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + if ( + name + in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + "UnCLIPPipeline", + ] + and is_transformers_version("<", "4.25.0") + ): + raise ImportError( + f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install --upgrade transformers \n```" + ) + if ( + name + in [ + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionPix2PixZeroPipeline", + ] + and is_transformers_version("<", "4.26.0") + ): + raise ImportError( + f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install --upgrade transformers \n```" + ) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for key, value in self.items(): + setattr(self, key, value) + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." 
+ ) + super().__setitem__(name, value) + + +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" + + +def extract_commit_hash( + resolved_file: "Optional[str]", commit_hash: "Optional[str]" = None +): + """ + Extracts the commit hash from a resolved filename toward a cache file. + """ + if resolved_file is None or commit_hash is not None: + return commit_hash + resolved_file = str(Path(resolved_file).as_posix()) + search = re.search("snapshots/([^/]+)/", resolved_file) + if search is None: + return None + commit_hash = search.groups()[0] + return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} + +SESSION_ID = uuid4().hex + +_flax_version = "N/A" + +_jax_version = "N/A" + +_onnxruntime_version = "N/A" + +_torch_version = "N/A" + + +class ConfigMixin: + """ + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + + config_name = None + ignore_for_config = [] + has_compatibles = False + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError( + f"Make sure that {self.__class__} has defined a class name `config_name`" + ) + kwargs.pop("kwargs", None) + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + self._internal_dict = FrozenDict(internal_dict) + + def save_config( + self, + save_directory: "Union[str, os.PathLike]", + push_to_hub: "bool" = False, + **kwargs, + ): + if os.path.isfile(save_directory): + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + os.makedirs(save_directory, exist_ok=True) + output_config_file = os.path.join(save_directory, self.config_name) + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + @classmethod + def from_config( + cls, + config: "Union[FrozenDict, Dict[str, Any]]" = None, + return_unused_kwargs=False, + **kwargs, + ): + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + if config is None: + raise ValueError( + "Please make sure to provide a config as the first positional argument." + ) + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." 
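+            # Legacy path: a model id or path was passed instead of a config dict; build a
+            # class-specific deprecation message, then fall back to `load_config` below.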
+ if "Scheduler" in cls.__name__: + deprecation_message += f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0." + elif "Model" in cls.__name__: + deprecation_message += f"If you were trying to load a model, please use {cls}.load_config(...) followed by {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0." + deprecate( + "config-passed-as-path", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + config, kwargs = cls.load_config( + pretrained_model_name_or_path=config, + return_unused_kwargs=True, + **kwargs, + ) + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + model = cls(**init_dict) + model.register_to_config(**hidden_dict) + unused_kwargs = {**unused_kwargs, **hidden_dict} + if return_unused_kwargs: + return model, unused_kwargs + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be removed in version v1.0.0" + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: "Union[str, os.PathLike]", + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Instantiate a Python class from a config dictionary + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False): + Whether unused keyword arguments of the config shall be returned. + return_commit_hash (`bool`, *optional*, defaults to `False): + Whether the commit_hash of the loaded configuration shall be returned. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from `ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + pass + try: + config_dict = cls._dict_from_json_file(config_file) + commit_hash = extract_commit_hash(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError( + f"It looks like the config file at '{config_file}' is not a valid JSON file." 
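+ # NB: unlike upstream diffusers, the remote-hub download branch above appears to be stubbed
+ # out (`else: pass`), so `config_file` is only defined for local files/directories in this
+ # linearized copy.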
+ ) + if not (return_unused_kwargs or return_commit_hash): + return config_dict + outputs = (config_dict,) + if return_unused_kwargs: + outputs += (kwargs,) + if return_commit_hash: + outputs += (commit_hash,) + return outputs + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + original_dict = {k: v for k, v in config_dict.items()} + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + if cls.has_compatibles: + compatible_classes = [ + c for c in cls._get_compatibles() if not isinstance(c, DummyObject) + ] + else: + compatible_classes = [] + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = { + k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls + } + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = { + k: v + for k, v in config_dict.items() + if k not in unexpected_keys_from_orig + } + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + init_dict = {} + for key in expected_keys: + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + if key in kwargs: + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + init_dict[key] = config_dict.pop(key) + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, but are not expected and will be ignored. Please verify your {cls.config_name} configuration file." + ) + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + unused_kwargs = {**config_dict, **kwargs} + hidden_config_dict = { + k: v for k, v in original_dict.items() if k not in init_dict + } + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: "Union[str, os.PathLike]"): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. 
+ """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: "Union[str, os.PathLike]"): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__(self, num_embed: "int", height: "int", width: "int", embed_dim: "int"): + super().__init__() + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + height_emb = self.height_emb( + torch.arange(self.height, device=index.device).view(1, self.height) + ) + height_emb = height_emb.unsqueeze(2) + width_emb = self.width_emb( + torch.arange(self.width, device=index.device).view(1, self.width) + ) + width_emb = width_emb.unsqueeze(1) + pos_emb = height_emb + width_emb + pos_emb = pos_emb.view(1, self.height * self.width, -1) + emb = emb + pos_emb[:, : emb.shape[1], :] + return emb + + +CONFIG_NAME = "config.json" + +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" + +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" + +WEIGHTS_NAME = "diffusion_pytorch_model.bin" + + +def _add_variant(weights_name: "str", variant: "Optional[str]" = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + return weights_name + + +DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, + commit_hash=None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif 
os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join( + pretrained_model_name_or_path, subfolder, weights_name + ) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + if ( + revision in DEPRECATED_REVISION_ARGS + and ( + weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME + ) + and version.parse(version.parse(__version__).base_version) + >= version.parse("0.17.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. 
\n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + state_dict = state_dict.copy() + error_msgs = [] + + def load(module: "torch.nn.Module", prefix=""): + args = state_dict, prefix, {}, True, [], [], error_msgs + module._load_from_state_dict(*args) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + return error_msgs + + +def get_parameter_device(parameter: "torch.nn.Module"): + try: + return next(parameter.parameters()).device + except StopIteration: + + def find_tensor_attributes( + module: "torch.nn.Module", + ) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: "torch.nn.Module"): + try: + return next(parameter.parameters()).dtype + except StopIteration: + + def find_tensor_attributes( + module: "torch.nn.Module", + ) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def is_torch_version(operation: "str", version: "str"): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def load_state_dict( + checkpoint_file: "Union[str, os.PathLike]", variant: "Optional[str]" = None +): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): + return torch.load(checkpoint_file, map_location="cpu") + else: + return safetensors.torch.load_file(checkpoint_file, device="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install git-lfs and run `git lfs install` followed by `git lfs pull` in the folder you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +class ModelMixin(torch.nn.Module): + """ + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. + + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~models.ModelMixin.save_pretrained`]. 
+ """ + + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + + def __init__(self): + super().__init__() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any( + hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing + for m in self.modules() + ) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing." + ) + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_memory_efficient_attention_xformers( + self, valid: "bool", attention_op: "Optional[Callable]" = None + ) -> None: + def fn_recursive_set_mem_eff(module: "torch.nn.Module"): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention( + self, attention_op: "Optional[Callable]" = None + ): + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + """ + Disable memory efficient attention as implemented in xformers. + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: "Union[str, os.PathLike]", + is_main_process: "bool" = True, + save_function: "Callable" = None, + safe_serialization: "bool" = False, + variant: "Optional[str]" = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~models.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
+ variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + if safe_serialization and not is_safetensors_available(): + raise ImportError( + "`safe_serialization` requires the `safetensors library: `pip install safetensors`." + ) + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + os.makedirs(save_directory, exist_ok=True) + model_to_save = self + if is_main_process: + model_to_save.save_config(save_directory) + state_dict = model_to_save.state_dict() + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + if safe_serialization: + safetensors.torch.save_file( + state_dict, + os.path.join(save_directory, weights_name), + metadata={"format": "pt"}, + ) + else: + torch.save(state_dict, os.path.join(save_directory, weights_name)) + logger.info( + f"Model weights saved in {os.path.join(save_directory, weights_name)}" + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: "Optional[Union[str, os.PathLike]]", + **kwargs, + ): + """ + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 
+ local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. 
+ + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip install accelerate\n```\n." + ) + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set `device_map=None`." + ) + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set `low_cpu_mem_usage=False`." + ) + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and dispatching. Please make sure to set `low_cpu_mem_usage=True`." 
+ ) + config_path = pretrained_model_name_or_path + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + user_agent=user_agent, + **kwargs, + ) + model_file = None + if from_flax: + model = cls.from_config(config, **unused_kwargs) + else: + if is_safetensors_available(): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except: + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + if low_cpu_mem_usage: + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + missing_keys = set(model.state_dict().keys()) - set( + state_dict.keys() + ) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are missing: \n {', '.join(missing_keys)}. \n Please make sure to pass `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize those weights or else make sure your checkpoint file is correct." + ) + for param_name, param in state_dict.items(): + accepts_dtype = "dtype" in set( + inspect.signature( + set_module_tensor_to_device + ).parameters.keys() + ) + if accepts_dtype: + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + dtype=torch_dtype, + ) + else: + set_module_tensor_to_device( + model, param_name, param_device, value=param + ) + else: + accelerate.load_checkpoint_and_dispatch( + model, model_file, device_map, dtype=torch_dtype + ) + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + state_dict = load_state_dict(model_file, variant=variant) + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
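+ # NB: in the `elif` branch below, upstream diffusers casts with `model.to(torch_dtype)`;
+ # this linearized copy leaves it as a no-op (`model = model`), presumably so the traced
+ # graph never contains an explicit dtype cast.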
+ ) + elif torch_dtype is not None: + model = model + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.eval() + if output_loading_info: + return model, loading_info + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + model_to_load = model + + def _find_mismatched_keys( + state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape + != model_state_dict[model_key].shape + ): + mismatched_keys.append( + ( + checkpoint_key, + state_dict[checkpoint_key].shape, + model_state_dict[model_key].shape, + ) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + raise RuntimeError( + f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}" + ) + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info( + f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" + ) + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use {model.__class__.__name__} for predictions without further training." 
+ ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} and are newly initialized because the shapes did not match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters( + self, only_trainable: "bool" = False, exclude_embeddings: "bool" = False + ) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter + for name, parameter in self.named_parameters() + if name not in embedding_param_names + ] + return sum( + p.numel() + for p in non_embedding_parameters + if p.requires_grad or not only_trainable + ) + else: + return sum( + p.numel() + for p in self.parameters() + if p.requires_grad or not only_trainable + ) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + emb_sin = np.sin(out) + emb_cos = np.cos(out) + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if 
cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + num_patches = height // patch_size * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=patch_size, + bias=bias, + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-06) + else: + self.norm = None + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** 0.5)) + self.register_buffer( + "pos_embed", + torch.from_numpy(pos_embed).float().unsqueeze(0), + persistent=False, + ) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) + if self.layer_norm: + latent = self.norm(latent) + return latent + self.pos_embed + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __post_init__(self): + class_fields = fields(self) + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all( + getattr(self, field.name) is None for field in class_fields[1:] + ) + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for k, v in self.items()} + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. 
+ """ + return tuple(self[k] for k in self.keys()) + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: "torch.FloatTensor" + + +def register_to_config(init): + """ + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does not inherit from `ConfigMixin`." + ) + ignore = getattr(self, "ignore_for_config", []) + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default + for i, (name, p) in enumerate(signature.parameters.items()) + if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. 
The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: "int" = 16, + attention_head_dim: "int" = 88, + in_channels: "Optional[int]" = None, + out_channels: "Optional[int]" = None, + num_layers: "int" = 1, + dropout: "float" = 0.0, + norm_num_groups: "int" = 32, + cross_attention_dim: "Optional[int]" = None, + attention_bias: "bool" = False, + sample_size: "Optional[int]" = None, + num_vector_embeds: "Optional[int]" = None, + patch_size: "Optional[int]" = None, + activation_fn: "str" = "geglu", + num_embeds_ada_norm: "Optional[int]" = None, + use_linear_projection: "bool" = False, + only_cross_attention: "bool" = False, + upcast_attention: "bool" = False, + norm_type: "str" = "layer_norm", + norm_elementwise_affine: "bool" = True, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.is_input_continuous = in_channels is not None and patch_size is None + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config. Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `transformer/config.json` file" + deprecate( + "norm_type!=num_embeds_ada_norm", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + norm_type = "ada_norm" + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make sure that either `num_vector_embeds` or `num_patches` is None." 
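+ # Of the three input modes validated here (continuous / vectorized / patched), the UNet
+ # blocks in this file appear to exercise only the continuous path (`in_channels` set,
+ # `patch_size` and `num_vector_embeds` left as None).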
+ ) + elif ( + not self.is_input_continuous + and not self.is_input_vectorized + and not self.is_input_patches + ): + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size: {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + if self.is_input_continuous: + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=1e-06, + affine=True, + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + assert ( + sample_size is not None + ), "Transformer2DModel over discrete input must provide sample_size" + assert ( + num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, + embed_dim=inner_dim, + height=self.height, + width=self.width, + ) + elif self.is_input_patches: + assert ( + sample_size is not None + ), "Transformer2DModel over patched input must provide sample_size" + self.height = sample_size + self.width = sample_size + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-06) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear( + inner_dim, patch_size * patch_size * self.out_channels + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: "bool" = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. 
Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + logits = logits.permute(0, 2, 1) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = ( + self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + ) + hidden_states = self.proj_out_2(hidden_states) + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height * self.patch_size, + width * self.patch_size, + ) + ) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. 
The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: "int" = 16, + attention_head_dim: "int" = 88, + in_channels: "Optional[int]" = None, + num_layers: "int" = 1, + dropout: "float" = 0.0, + norm_num_groups: "int" = 32, + cross_attention_dim: "Optional[int]" = None, + attention_bias: "bool" = False, + sample_size: "Optional[int]" = None, + num_vector_embeds: "Optional[int]" = None, + activation_fn: "str" = "geglu", + num_embeds_ada_norm: "Optional[int]" = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + self.mix_ratio = 0.5 + self.condition_lengths = [77, 257] + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: "bool" = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
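+
+        Note (summarizing the defaults set in `__init__` above): `encoder_hidden_states` is expected to be the
+        concatenation of a 77-token and a 257-token condition along the sequence dimension; the two slices are
+        routed to `self.transformers[1]` and `self.transformers[0]` respectively, and their residuals are mixed
+        50/50 (`mix_ratio=0.5`) before being added back to the input.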
+ + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + input_states = hidden_states + encoded_states = [] + tokens_start = 0 + for i in range(2): + condition_state = encoder_hidden_states[ + :, tokens_start : tokens_start + self.condition_lengths[i] + ] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * ( + 1 - self.mix_ratio + ) + output_states = output_states + input_states + if not return_dict: + return (output_states,) + return Transformer2DModelOutput(sample=output_states) + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, + embedding_size: "int" = 256, + scale: "float" = 1.0, + set_W_to_weight=True, + log=True, + flip_sin_to_cos=False, + ): + super().__init__() + self.weight = nn.Parameter( + torch.randn(embedding_size) * scale, requires_grad=False + ) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + if set_W_to_weight: + self.W = nn.Parameter( + torch.randn(embedding_size) * scale, requires_grad=False + ) + self.weight = self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + Args: + predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: "torch.FloatTensor" + + +class PriorTransformer(ModelMixin, ConfigMixin): + """ + The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the + transformer predicts the image embeddings through a denoising diffusion process. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP + image embeddings and text embeddings are both the same dimension. + num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the + length of the prompt after it has been tokenized. + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected hidden_states. 
The actual length of the used hidden_states is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + + """ + + @register_to_config + def __init__( + self, + num_attention_heads: "int" = 32, + attention_head_dim: "int" = 64, + num_layers: "int" = 20, + embedding_dim: "int" = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: "float" = 0.0, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) + self.proj_in = nn.Linear(embedding_dim, inner_dim) + self.embedding_proj = nn.Linear(embedding_dim, inner_dim) + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + self.positional_embedding = nn.Parameter( + torch.zeros(1, num_embeddings + additional_embeddings, inner_dim) + ) + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + self.norm_out = nn.LayerNorm(inner_dim) + self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) + causal_attention_mask = torch.full( + [ + num_embeddings + additional_embeddings, + num_embeddings + additional_embeddings, + ], + -10000.0, + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer( + "causal_attention_mask", causal_attention_mask, persistent=False + ) + self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) + + def forward( + self, + hidden_states, + timestep: "Union[torch.Tensor, float, int]", + proj_embedding: "torch.FloatTensor", + encoder_hidden_states: "torch.FloatTensor", + attention_mask: "Optional[torch.BoolTensor]" = None, + return_dict: "bool" = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + x_t, the currently predicted image embeddings. + timestep (`torch.long`): + Current denoising step. + proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain + tuple. + + Returns: + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: + [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
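+
+                As a shape check under the default config (`embedding_dim=768`, `num_embeddings=77`):
+                `hidden_states` is `(batch_size, 768)`, `encoder_hidden_states` is `(batch_size, 77, 768)`,
+                and the returned `predicted_image_embedding` is `(batch_size, 768)`.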
+ """ + batch_size = hidden_states.shape[0] + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor( + [timesteps], dtype=torch.long, device=hidden_states.device + ) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps * torch.ones( + batch_size, dtype=timesteps.dtype, device=timesteps.device + ) + timesteps_projected = self.time_proj(timesteps) + timesteps_projected = timesteps_projected + time_embeddings = self.time_embedding(timesteps_projected) + proj_embeddings = self.embedding_proj(proj_embedding) + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + hidden_states = self.proj_in(hidden_states) + prd_embedding = self.prd_embedding.expand(batch_size, -1, -1) + positional_embeddings = self.positional_embedding + hidden_states = torch.cat( + [ + encoder_hidden_states, + proj_embeddings[:, None, :], + time_embeddings[:, None, :], + hidden_states[:, None, :], + prd_embedding, + ], + dim=1, + ) + hidden_states = hidden_states + positional_embeddings + if attention_mask is not None: + attention_mask = (1 - attention_mask) * -10000.0 + attention_mask = F.pad( + attention_mask, (0, self.additional_embeddings), value=0.0 + ) + attention_mask = attention_mask[:, None, :] + self.causal_attention_mask + attention_mask = attention_mask.repeat_interleave( + self.config.num_attention_heads, dim=0 + ) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states[:, -1] + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + if not return_dict: + return (predicted_image_embedding,) + return PriorTransformerOutput( + predicted_image_embedding=predicted_image_embedding + ) + + def post_process_latents(self, prior_latents): + prior_latents = prior_latents * self.clip_std + self.clip_mean + return prior_latents + + +class Upsample1D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample1D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + if use_conv: + self.conv = nn.Conv1d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + +class Upsample2D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(hidden_states) + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=2.0, mode="nearest" + ) + else: + hidden_states = F.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + if dtype == torch.bfloat16: + hidden_states = hidden_states + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + return hidden_states + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = 0, 1, 0, 1 + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + return hidden_states + + +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + return out.view(-1, channel, out_h, out_w) + + +class FirUpsample2D(nn.Module): + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d( + channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. 
+ factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * (gain * factor ** 2) + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + pad_value = kernel.shape[0] - factor - (convW - 1) + stride = factor, factor + output_shape = (hidden_states.shape[2] - 1) * factor + convH, ( + hidden_states.shape[3] - 1 + ) * factor + convW + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // inC + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + inverse_conv = F.conv_transpose2d( + hidden_states, + weight, + stride=stride, + output_padding=output_padding, + padding=0, + ) + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + def forward(self, hidden_states): + if self.use_conv: + height = self._upsample_2d( + hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel + ) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + return height + + +class FirDownsample2D(nn.Module): + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d( + channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and + same datatype as `x`. + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * gain + if self.use_conv: + _, _, convH, convW = weight.shape + pad_value = kernel.shape[0] - factor + (convW - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output + + def forward(self, hidden_states): + if self.use_conv: + downsample_input = self._downsample_2d( + hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel + ) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d( + hidden_states, kernel=self.fir_kernel, factor=2 + ) + return hidden_states + + +class KDownsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, x): + x = F.pad(x, (self.pad,) * 4, self.pad_mode) + weight = x.new_zeros( + [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] + ) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel + return F.conv2d(x, weight, stride=2) + + +class KUpsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, x): + x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = x.new_zeros( + [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] + ) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel + return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) + + +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): + """Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, kernel, down=factor, pad=((pad_value + 1) // 2, pad_value // 2) + ) + return output + + +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): + """Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * (gain * factor ** 2) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel, + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + +class ResnetBlock2D(nn.Module): + """ + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. 
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-06, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + if groups_out is None: + groups_out = groups + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + else: + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group": + self.time_emb_proj = None + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + else: + self.time_emb_proj = None + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + else: + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = torch.nn.Conv2d( + out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1 + ) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = 1, 3, 3, 1 + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = 1, 3, 3, 1 + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D( + in_channels, use_conv=False, padding=1, name="op" + ) + self.use_in_shortcut = ( + self.in_channels != conv_2d_out_channels + if 
use_in_shortcut is None + else use_in_shortcut + ) + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + hidden_states = self.conv1(hidden_states) + if self.time_emb_proj is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + + +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + self.conv1d = nn.Conv1d( + inp_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, x): + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x + + +class ResidualTemporalBlock1D(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) + if inp_channels != out_channels + else nn.Identity() + ) + + def forward(self, x, t): + """ + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(x) + rearrange_dims(t) + out = 
self.conv_out(out) + return out + self.residual_conv(x) + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): + Hidden states output. Output of last layer of model. + """ + + sample: "torch.FloatTensor" + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = ( + [in_features, second_dim, 1] + if isinstance(in_features, int) + else list(in_features) + ) + if out_features is None: + out_features = in_features + out_features = ( + [out_features, second_dim, 1] + if isinstance(out_features, int) + else list(out_features) + ) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view( + *shape[0:-n_dim], *self.out_features_multidim + ) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-06, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + in_channels = ( + [in_channels, second_dim, 1] + if isinstance(in_channels, int) + else list(in_channels) + ) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + if out_channels is not None: + out_channels = ( + [out_channels, second_dim, 1] + if isinstance(out_channels, int) + else list(out_channels) + ) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + if groups_out is None: + groups_out = groups + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True + ) + self.conv1 = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0 + ) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels_prod, out_channels_prod, kernel_size=1, padding=0 + ) + self.nonlinearity = nn.SiLU() + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod + if use_in_shortcut is None + else use_in_shortcut + ) + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, + out_channels_prod, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape( + *shape[0:-n_dim], self.in_channels_prod, 1, 1 + ) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + hidden_states = input_tensor + hidden_states = 
self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = input_tensor + hidden_states + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view( + *shape[0:-n_dim], *self.out_channels_multidim + ) + return output_tensor + + +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + 
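+ # Checkpoint the attention block with return_dict=False so torch.utils.checkpoint
+ # receives a plain tuple; the sample tensor is then read from index [0] below.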
hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +# def get_down_block( +# down_block_type, +# num_layers, +# in_channels, +# out_channels, +# temb_channels, +# add_downsample, +# resnet_eps, +# resnet_act_fn, +# attn_num_head_channels, +# resnet_groups=None, +# cross_attention_dim=None, +# downsample_padding=None, +# dual_cross_attention=False, +# use_linear_projection=False, +# only_cross_attention=False, +# upcast_attention=False, +# resnet_time_scale_shift="default", +# ): +# down_block_type = ( +# down_block_type[7:] +# if down_block_type.startswith("UNetRes") +# else down_block_type +# ) +# if down_block_type == "DownBlockFlat": +# return DownBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# temb_channels=temb_channels, +# add_downsample=add_downsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# downsample_padding=downsample_padding, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# elif down_block_type == "CrossAttnDownBlockFlat": +# if cross_attention_dim is None: +# raise 
ValueError( +# "cross_attention_dim must be specified for CrossAttnDownBlockFlat" +# ) +# return CrossAttnDownBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# temb_channels=temb_channels, +# add_downsample=add_downsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# downsample_padding=downsample_padding, +# cross_attention_dim=cross_attention_dim, +# attn_num_head_channels=attn_num_head_channels, +# dual_cross_attention=dual_cross_attention, +# use_linear_projection=use_linear_projection, +# only_cross_attention=only_cross_attention, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# raise ValueError(f"{down_block_type} is not supported.") + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock2D" + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" + ) + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + 
in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dim, + num_layers: "int" = 1, + add_downsample: "bool" = False, + add_upsample: "bool" = False, + non_linearity=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + resnets = [ + ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim) + ] + for _ in range(num_layers): + resnets.append( + ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim) + ) + self.resnets = nn.ModuleList(resnets) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + 
self.nonlinearity = None + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states, temb): + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [ + -0.01171875, + -0.03515625, + 0.11328125, + 0.43359375, + 0.43359375, + 0.11328125, + -0.03515625, + -0.01171875, + ], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states): + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros( + [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]] + ) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel + return F.conv1d(hidden_states, weight, stride=2) + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, is_last=False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states): + residual = ( + self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + ) + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + output = hidden_states + residual + return output + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels, n_head=1, dropout_rate=0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + self.proj_attn = nn.Linear(self.channels, self.channels, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: "torch.Tensor") -> torch.Tensor: + 
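+ # Reshape (batch, seq_len, channels) -> (batch, num_heads, seq_len, channels // num_heads)
+ # so per-head attention scores can be computed with a single batched matmul.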
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + attention_scores = torch.matmul( + query_states * scale, key_states.transpose(-1, -2) * scale + ) + attention_probs = torch.softmax(attention_scores, dim=-1) + hidden_states = torch.matmul(attention_probs, value_states) + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + output = hidden_states + residual + return output + + +class Upsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states, temb=None): + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros( + [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]] + ) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel + return F.conv_transpose1d( + hidden_states, weight, stride=2, padding=self.pad * 2 + 1 + ) + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels, in_channels, out_channels=None): + super().__init__() + out_channels = in_channels if out_channels is None else out_channels + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + 
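# Illustrative sketch (not from the ported diffusers code) of the 1D resampling blocks defined
# above. It assumes the flattened module is importable as `pi.models.unet` once this patch is
# applied; tensors are laid out as (batch, channels, length), and the shapes below follow from
# the kernel/stride arithmetic in Downsample1d, Upsample1d and UNetMidBlock1D.
import torch
from pi.models.unet import Downsample1d, UNetMidBlock1D, Upsample1d

x = torch.randn(1, 32, 64)

down = Downsample1d("linear")  # FIR low-pass kernel + stride-2 conv: halves the length
up = Upsample1d("linear")      # matching stride-2 transposed conv: doubles the length
assert down(x).shape == (1, 32, 32)
assert up(x).shape == (1, 32, 128)

# UNetMidBlock1D sandwiches ResConvBlock/SelfAttention1d pairs between a "cubic" downsample and
# upsample, so the output keeps the input length while mapping 32 -> 64 -> 32 channels internally.
mid = UNetMidBlock1D(mid_channels=64, in_channels=32, out_channels=32)
assert mid(x).shape == (1, 32, 64)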
self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.res1 = ResidualTemporalBlock1D( + in_channels, in_channels // 2, embed_dim=embed_dim + ) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D( + in_channels // 2, in_channels // 4, embed_dim=embed_dim + ) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +def get_mid_block( + mid_block_type, + num_layers, + in_channels, + mid_channels, + out_channels, + embed_dim, + add_downsample, +): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D( + in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim + ) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D( + in_channels=in_channels, + mid_channels=mid_channels, + out_channels=out_channels, + ) + raise ValueError(f"{mid_block_type} does not exist.") + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states, temb=None): + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + return hidden_states + + +def get_out_block( + *, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim +): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) + return None + + +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + prev_output_channel: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + 
upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == 
num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +# def get_up_block( +# up_block_type, +# num_layers, +# in_channels, +# out_channels, +# prev_output_channel, +# temb_channels, +# add_upsample, +# resnet_eps, +# resnet_act_fn, +# attn_num_head_channels, +# resnet_groups=None, +# cross_attention_dim=None, +# dual_cross_attention=False, +# use_linear_projection=False, +# only_cross_attention=False, +# upcast_attention=False, +# resnet_time_scale_shift="default", +# ): +# up_block_type = ( +# up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type +# ) +# if up_block_type == "UpBlockFlat": +# return UpBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# prev_output_channel=prev_output_channel, +# temb_channels=temb_channels, +# add_upsample=add_upsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# elif up_block_type == "CrossAttnUpBlockFlat": +# if cross_attention_dim is None: +# raise ValueError( +# "cross_attention_dim must be specified for CrossAttnUpBlockFlat" +# ) +# return CrossAttnUpBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# prev_output_channel=prev_output_channel, +# temb_channels=temb_channels, +# add_upsample=add_upsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# cross_attention_dim=cross_attention_dim, +# attn_num_head_channels=attn_num_head_channels, +# dual_cross_attention=dual_cross_attention, +# use_linear_projection=use_linear_projection, +# only_cross_attention=only_cross_attention, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# raise ValueError(f"{up_block_type} is not supported.") +# + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + 
dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D" + ) + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" + ) + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + 
resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNet1DModel(ModelMixin, ConfigMixin): + """ + UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(32, 32, 64)`): Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. + act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. + downsample_each_block (`int`, *optional*, defaults to False: + experimental feature for using a UNet without upsampling. 
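    Example (an illustrative sketch, not from the ported diffusers docstring and not validated
    here): it assumes the flattened module is importable as `pi.models.unet` and that the
    `get_down_block`/`get_up_block` helpers in this file dispatch the default 1D block types.
    `extra_in_channels=16` matches the 16-channel Fourier time embedding that the first
    `DownBlock1DNoSkip` concatenates onto the input.

        import torch
        from pi.models.unet import UNet1DModel

        model = UNet1DModel(in_channels=2, out_channels=2, extra_in_channels=16)
        model.eval()

        sample = torch.randn(1, 2, 16384)  # (batch, channels, length)
        with torch.no_grad():
            out = model(sample, timestep=10)
        print(out.sample.shape)  # torch.Size([1, 2, 16384]) if the assumptions above hold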
+ """ + + @register_to_config + def __init__( + self, + sample_size: "int" = 65536, + sample_rate: "Optional[int]" = None, + in_channels: "int" = 2, + out_channels: "int" = 2, + extra_in_channels: "int" = 0, + time_embedding_type: "str" = "fourier", + flip_sin_to_cos: "bool" = True, + use_timestep_embedding: "bool" = False, + freq_shift: "float" = 0.0, + down_block_types: "Tuple[str]" = ( + "DownBlock1DNoSkip", + "DownBlock1D", + "AttnDownBlock1D", + ), + up_block_types: "Tuple[str]" = ( + "AttnUpBlock1D", + "UpBlock1D", + "UpBlock1DNoSkip", + ), + mid_block_type: "Tuple[str]" = "UNetMidBlock1D", + out_block_type: "str" = None, + block_out_channels: "Tuple[int]" = (32, 32, 64), + act_fn: "str" = None, + norm_num_groups: "int" = 8, + layers_per_block: "int" = 1, + downsample_each_block: "bool" = False, + ): + super().__init__() + self.sample_size = sample_size + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=freq_shift, + ) + timestep_input_dim = block_out_channels[0] + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + if i == 0: + input_channel += extra_in_channels + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] + if i < len(up_block_types) - 1 + else final_upsample_channels + ) + is_final_block = i == len(block_out_channels) - 1 + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + num_groups_out = ( + norm_num_groups + if norm_num_groups is not None + else min(block_out_channels[0] // 4, 32) + ) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + 
embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + return_dict: "bool" = True, + ) -> Union[UNet1DOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor( + [timesteps], dtype=torch.long, device=sample.device + ) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]) + timestep_embed = timestep_embed.broadcast_to( + sample.shape[:1] + timestep_embed.shape[1:] + ) + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block( + hidden_states=sample, temb=timestep_embed + ) + down_block_res_samples += res_samples + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block( + sample, res_hidden_states_tuple=res_samples, temb=timestep_embed + ) + if self.out_block: + sample = self.out_block(sample, timestep_embed) + if not return_dict: + return (sample,) + return UNet1DOutput(sample=sample) + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + if groups_out is None: + groups_out = groups + resnets = [ + ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels) + ] + for _ in range(num_layers): + resnets.append( + ResidualTemporalBlock1D( + out_channels, out_channels, embed_dim=temb_channels + ) + ) + self.resnets = nn.ModuleList(resnets) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in 
self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + if groups_out is None: + groups_out = groups + resnets = [ + ResidualTemporalBlock1D( + 2 * in_channels, out_channels, embed_dim=temb_channels + ) + ] + for _ in range(num_layers): + resnets.append( + ResidualTemporalBlock1D( + out_channels, out_channels, embed_dim=temb_channels + ) + ) + self.resnets = nn.ModuleList(resnets) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + self.resnets = nn.ModuleList(resnets) + + def 
forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + return hidden_states + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + Args: + sample 
(`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states output. Output of last layer of model. + """ + + sample: "torch.FloatTensor" + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + add_attention: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + if self.add_attention: + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append(None) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNet2DModel(ModelMixin, ConfigMixin): + """ + UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`True`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block + types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(224, 448, 672, 896)`): Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + """ + + @register_to_config + def __init__( + self, + sample_size: "Optional[Union[int, Tuple[int, int]]]" = None, + in_channels: "int" = 3, + out_channels: "int" = 3, + center_input_sample: "bool" = False, + time_embedding_type: "str" = "positional", + freq_shift: "int" = 0, + flip_sin_to_cos: "bool" = True, + down_block_types: "Tuple[str]" = ( + "DownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ), + up_block_types: "Tuple[str]" = ( + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + ), + block_out_channels: "Tuple[int]" = (224, 448, 672, 896), + layers_per_block: "int" = 2, + mid_block_scale_factor: "float" = 1, + downsample_padding: "int" = 1, + act_fn: "str" = "silu", + attention_head_dim: "Optional[int]" = 8, + norm_num_groups: "int" = 32, + norm_eps: "float" = 1e-05, + resnet_time_scale_shift: "str" = "default", + add_attention: "bool" = True, + class_embed_type: "Optional[str]" = None, + num_class_embeds: "Optional[int]" = None, + ): + super().__init__() + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
+ ) + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) + ) + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=block_out_channels[0], scale=16 + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + add_attention=add_attention, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + is_final_block = i == len(block_out_channels) - 1 + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + num_groups_out = ( + norm_num_groups + if norm_num_groups is not None + else min(block_out_channels[0] // 4, 32) + ) + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps + ) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=3, padding=1 + ) + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + class_labels: "Optional[torch.Tensor]" = None, + return_dict: "bool" = True, + ) 
-> Union[UNet2DOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor( + [timesteps], dtype=torch.long, device=sample.device + ) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps * torch.ones( + sample.shape[0], dtype=timesteps.dtype, device=timesteps.device + ) + t_emb = self.time_proj(timesteps) + t_emb = t_emb + emb = self.time_embedding(t_emb) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when doing class conditioning" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + skip_sample = sample + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + sample = self.mid_block(sample, emb) + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block( + sample, res_samples, emb, skip_sample + ) + else: + sample = upsample_block(sample, res_samples, emb) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if skip_sample is not None: + sample += skip_sample + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape( + (sample.shape[0], *([1] * len(sample.shape[1:]))) + ) + sample = sample / timesteps + if not return_dict: + return (sample,) + return UNet2DOutput(sample=sample) + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) 
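# An illustrative usage sketch (not from the ported diffusers code) for the UNet2DModel defined
# above, with a deliberately tiny, untrained configuration. It assumes the flattened module is
# importable as `pi.models.unet` and that the `get_down_block`/`get_up_block` factories shown
# earlier in this file are the definitions in effect for the 2D block names used here.
import torch
from pi.models.unet import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    attention_head_dim=8,
)
model.eval()

sample = torch.randn(1, 3, 32, 32)  # (batch, channel, height, width)
with torch.no_grad():
    out = model(sample, timestep=torch.tensor([10]))
print(out.sample.shape)  # torch.Size([1, 3, 32, 32]) if the assumptions above hold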
+ resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.num_heads = in_channels // self.attn_num_head_channels + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + attentions.append( + CrossAttention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, 
+ output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + 
in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + 
out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + 
out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=np.sqrt(2.0), + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList( + [FirDownsample2D(out_channels, out_channels=out_channels)] + ) + self.skip_conv = nn.Conv2d( + 3, out_channels, kernel_size=(1, 1), stride=(1, 1) + ) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + hidden_states = self.skip_conv(skip_sample) + hidden_states + output_states += (hidden_states,) + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + 
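+ # the skip/FIR blocks derive group-norm group counts from the channel counts (min(channels // 4, 32)) rather than taking a resnet_groups argument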
ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList( + [FirDownsample2D(out_channels, out_channels=out_channels)] + ) + self.skip_conv = nn.Conv2d( + 3, out_channels, kernel_size=(1, 1), stride=(1, 1) + ) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + hidden_states = self.skip_conv(skip_sample) + hidden_states + output_states += (hidden_states,) + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + 
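+ # each "downsampler" here is a ResnetBlock2D built with down=True, so it also receives the time embedding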
hidden_states = downsampler(hidden_states, temb) + output_states += (hidden_states,) + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.has_cross_attention = True + resnets = [] + attentions = [] + self.attn_num_head_channels = attn_num_head_channels + self.num_heads = out_channels // self.attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + CrossAttention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + output_states += (hidden_states,) + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 4, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + resnet_group_size: "int" = 32, + add_downsample=False, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + 
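+ # the "ada_group" time-embedding norm set below conditions these resnets on temb via AdaGroupNorm (K-diffusion-style blocks)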
groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states, output_states + + +class KAttentionBlock(nn.Module): + """ + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
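+
+ Note: this block operates on 4-D feature maps; `_to_3d`/`_to_4d` flatten the spatial dimensions for attention and restore them afterwards.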
+ """ + + def __init__( + self, + dim: "int", + num_attention_heads: "int", + attention_head_dim: "int", + dropout: "float" = 0.0, + cross_attention_dim: "Optional[int]" = None, + attention_bias: "bool" = False, + upcast_attention: "bool" = False, + temb_channels: "int" = 768, + add_self_attention: "bool" = False, + cross_attention_norm: "bool" = False, + group_size: "int" = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 3, 1).reshape( + hidden_states.shape[0], height * weight, -1 + ) + + def _to_4d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 1).reshape( + hidden_states.shape[0], -1, height, weight + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + emb=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn1( + norm_hidden_states, encoder_hidden_states=None, **cross_attention_kwargs + ) + attn_output = self._to_4d(attn_output, height, weight) + hidden_states = attn_output + hidden_states + norm_hidden_states = self.norm2(hidden_states, emb) + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + hidden_states = attn_output + hidden_states + return hidden_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + cross_attention_dim: "int", + dropout: "float" = 0.0, + num_layers: "int" = 4, + resnet_group_size: "int" = 32, + add_downsample=True, + attn_num_head_channels: "int" = 64, + add_self_attention: "bool" = False, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + 
out_channels, + out_channels // attn_num_head_channels, + attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm=True, + group_size=resnet_group_size, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple 
= res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + prev_output_channel: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + 
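+ # return_dict=False makes the wrapped transformer return a plain tuple, hence the trailing [0]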
)[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = 
resnet(hidden_states, temb=None) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=np.sqrt(2.0), + upsample_padding=1, + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + 
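+ # this resnet performs the 2x upsample itself (up=True below) using a FIR resample kernel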
time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d( + out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True, + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = self.attentions[0](hidden_states) + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + skip_sample = skip_sample + skip_sample_states + hidden_states = self.resnet_up(hidden_states, temb) + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d( + out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True, + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None + ): + for resnet in 
self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + skip_sample = skip_sample + skip_sample_states + hidden_states = self.resnet_up(hidden_states, temb) + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + prev_output_channel: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + 
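+ # "simple" variant: plain CrossAttention with added key/value projections (CrossAttnAddedKVProcessor) instead of a Transformer2DModel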
self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + self.num_heads = out_channels // self.attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + CrossAttention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 5, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + resnet_group_size: "Optional[int]" = 32, + add_upsample=True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=k_out_channels + if i == num_layers - 1 + else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = 
None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 4, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + resnet_group_size: "int" = 32, + attn_num_head_channels=1, + cross_attention_dim: "int" = 768, + add_upsample: "bool" = True, + upcast_attention: "bool" = False, + ): + super().__init__() + resnets = [] + attentions = [] + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + if is_middle_block and i == num_layers - 1: + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if i == num_layers - 1 else out_channels, + k_out_channels // attn_num_head_channels + if i == num_layers - 1 + else out_channels // attn_num_head_channels, + attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm=True, + upcast_attention=upcast_attention, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): 
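+ # wrap the module in a plain callable so torch.utils.checkpoint can re-run it during the backward pass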
+ def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +AttnProcessor = Union[ + CrossAttnProcessor, + XFormersCrossAttnProcessor, + SlicedAttnProcessor, + CrossAttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + LoRACrossAttnProcessor, + LoRAXFormersCrossAttnProcessor, +] + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" + +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + + +class UNet2DConditionLoadersMixin: + def load_attn_procs( + self, + pretrained_model_name_or_path_or_dict: "Union[str, Dict[str, torch.Tensor]]", + **kwargs, + ): + """ + Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be + defined in + [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + and be a `torch.nn.Module` class. + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + if ( + is_safetensors_available() + and weight_name is None + or weight_name is not None + and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except EnvironmentError: + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + attn_processors = {} + is_lora = all("lora" in k for k in state_dict.keys()) + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join( + key.split(".")[-3:] + ) + lora_grouped_dict[attn_processor_key][sub_key] = value + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + attn_processors[key] = LoRACrossAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + ) + 
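+ # value_dict holds this processor's LoRA up/down weights (e.g. "to_k_lora.down.weight"); rank and dims were read from them above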
attn_processors[key].load_state_dict(value_dict) + else: + raise ValueError( + f"{model_file} does not seem to be in the correct format expected by LoRA training." + ) + attn_processors = {k: v for k, v in attn_processors.items()} + self.set_attn_processor(attn_processors) + + def save_attn_procs( + self, + save_directory: "Union[str, os.PathLike]", + is_main_process: "bool" = True, + weight_name: "str" = None, + save_function: "Callable" = None, + safe_serialization: "bool" = False, + **kwargs, + ): + """ + Save an attention processor to a directory, so that it can be re-loaded using the + `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + weight_name = weight_name or deprecate( + "weights_name", + "0.18.0", + "`weights_name` is deprecated, please use `weight_name` instead.", + take_from=kwargs, + ) + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file( + weights, filename, metadata={"format": "pt"} + ) + + else: + save_function = torch.save + os.makedirs(save_directory, exist_ok=True) + model_to_save = AttnProcsLayers(self.attn_processors) + state_dict = model_to_save.state_dict() + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info( + f"Model weights saved in {os.path.join(save_directory, weight_name)}" + ) + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: "torch.FloatTensor" + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + """ + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the + mid block layer if `None`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, or `"projection"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
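+ 
+     Example (a minimal sketch, not taken from this patch: the small config values below are
+     illustrative assumptions, chosen so channel counts stay multiples of the default
+     `norm_num_groups=32` and of `attention_head_dim`):
+ 
+     ```py
+     import torch
+ 
+     unet = UNet2DConditionModel(
+         sample_size=16,
+         block_out_channels=(64, 128),
+         down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+         up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
+         cross_attention_dim=64,
+         attention_head_dim=8,
+         layers_per_block=1,
+     )
+     unet.eval()
+     ```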
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: "Optional[int]" = None, + in_channels: "int" = 4, + out_channels: "int" = 4, + center_input_sample: "bool" = False, + flip_sin_to_cos: "bool" = True, + freq_shift: "int" = 0, + down_block_types: "Tuple[str]" = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: "Optional[str]" = "UNetMidBlock2DCrossAttn", + up_block_types: "Tuple[str]" = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: "Union[bool, Tuple[bool]]" = False, + block_out_channels: "Tuple[int]" = (320, 640, 1280, 1280), + layers_per_block: "int" = 2, + downsample_padding: "int" = 1, + mid_block_scale_factor: "float" = 1, + act_fn: "str" = "silu", + norm_num_groups: "Optional[int]" = 32, + norm_eps: "float" = 1e-05, + cross_attention_dim: "int" = 1280, + attention_head_dim: "Union[int, Tuple[int]]" = 8, + dual_cross_attention: "bool" = False, + use_linear_projection: "bool" = False, + class_embed_type: "Optional[str]" = None, + num_class_embeds: "Optional[int]" = None, + upcast_attention: "bool" = False, + resnet_time_scale_shift: "str" = "default", + time_embedding_type: "str" = "positional", + timestep_post_act: "Optional[str]" = None, + time_cond_proj_dim: "Optional[int]" = None, + conv_in_kernel: "int" = 3, + conv_out_kernel: "int" = 3, + projection_class_embeddings_input_dim: "Optional[int]" = None, + ): + super().__init__() + self.sample_size = sample_size + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. 
Pleaes make sure to use one of `fourier` or `positional`." + ) + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + self.num_upsamplers = 0 + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel 
= reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + """ + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + processors = {} + + def fn_recursive_add_processors( + name: "str", + module: "torch.nn.Module", + processors: "Dict[str, AttnProcessor]", + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + return processors + + def set_attn_processor( + self, processor: "Union[AttnProcessor, Dict[str, AttnProcessor]]" + ): + """ + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: + + """ + count = len(self.attn_processors.keys()) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor( + name: "str", module: "torch.nn.Module", processor + ): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_attention_slice(self, slice_size): + """ + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: "torch.nn.Module"): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": + slice_size = [(dim // 2) for dim in sliceable_head_dims] + elif slice_size == "max": + slice_size = num_slicable_layers * [1] + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
+ ) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + def fn_recursive_set_attention_slice( + module: "torch.nn.Module", slice_size: "List[int]" + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D) + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + encoder_hidden_states: "torch.Tensor", + class_labels: "Optional[torch.Tensor]" = None, + timestep_cond: "Optional[torch.Tensor]" = None, + attention_mask: "Optional[torch.Tensor]" = None, + cross_attention_kwargs: "Optional[Dict[str, Any]]" = None, + down_block_additional_residuals: "Optional[Tuple[torch.Tensor]]" = None, + mid_block_additional_residual: "Optional[torch.Tensor]" = None, + return_dict: "bool" = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
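+ 
+         Example (an illustrative sketch only; it assumes `unet` was built with `in_channels=4` and
+         `cross_attention_dim=64` as in the class-level example above):
+ 
+         ```py
+         import torch
+ 
+         sample = torch.randn(2, 4, 16, 16)             # (batch, channel, height, width)
+         timestep = torch.tensor(10)                    # scalar timestep, broadcast to the batch
+         encoder_hidden_states = torch.randn(2, 8, 64)  # (batch, sequence_length, cross_attention_dim)
+ 
+         out = unet(sample, timestep, encoder_hidden_states)   # UNet2DConditionOutput
+         noise_pred = out.sample                                # (2, 4, 16, 16)
+         noise_pred, = unet(sample, timestep, encoder_hidden_states, return_dict=False)
+         ```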
+ """ + default_overall_up_factor = 2 ** self.num_upsamplers + forward_upsample_size = False + upsample_size = None + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + if attention_mask is not None: + attention_mask = (1 - attention_mask) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + timesteps = timestep + if not torch.is_tensor(timesteps): + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + t_emb = t_emb + emb = self.time_embedding(t_emb, timestep_cond) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + if self.conv_norm_out: + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if not return_dict: + return (sample,) + return UNet2DConditionOutput(sample=sample) + + +def rename_key(key): + regex = "\\w+[.]\\d+" + pats = re.findall(regex, key) + for pat in pats: + key = key.replace(pat, "_".join(pat.split("."))) + return key + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.conv_in = torch.nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1 + ) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-06, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-06, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-06 + ) + self.conv_act = nn.SiLU() + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d( + block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + for down_block in self.down_blocks: + sample = down_block(sample) + sample = self.mid_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 + ) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-06, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + 
resnet_eps=1e-06, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-06 + ) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward(self, z): + sample = z + sample = self.conv_in(sample) + sample = self.mid_block(sample) + for up_block in self.up_blocks: + sample = up_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + def __init__( + self, + n_e, + vq_embed_dim, + beta, + remap=None, + unknown_index="random", + sane_index_shape=False, + legacy=True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + None + else: + self.re_embed = n_e + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used + if self.re_embed > self.used.shape[0]: + inds[inds >= self.used.shape[0]] = 0 + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + min_encoding_indices = torch.argmin( + torch.cdist(z_flattened, self.embedding.weight), dim=1 + ) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( + (z_q - z.detach()) ** 2 + ) + z_q = z + (z_q - z).detach() + z_q = z_q.permute(0, 3, 1, 2).contiguous() + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, 
shape): + if self.remap is not None: + indices = indices.reshape(shape[0], -1) + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) + z_q = self.embedding(indices) + if shape is not None: + z_q = z_q.view(shape) + z_q = z_q.permute(0, 3, 1, 2).contiguous() + return z_q + + +@dataclass +class DecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Decoded output sample of the model. Output of the last layer of the model. + """ + + sample: "torch.FloatTensor" + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Encoded output sample of the model. Output of the last layer of the model. + """ + + latents: "torch.FloatTensor" + + +class VQModel(ModelMixin, ConfigMixin): + """VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
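+ 
+     Example (a hedged sketch using the defaults documented above; the 32x32 input simply mirrors the
+     default `sample_size` and is not required):
+ 
+     ```py
+     import torch
+ 
+     vq = VQModel()                              # single default block, so no down/upsampling
+     image = torch.randn(1, 3, 32, 32)
+     latents = vq.encode(image).latents          # (1, 3, 32, 32), latent_channels=3
+     reconstruction = vq.decode(latents).sample  # (1, 3, 32, 32)
+     ```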
+ """ + + @register_to_config + def __init__( + self, + in_channels: "int" = 3, + out_channels: "int" = 3, + down_block_types: "Tuple[str]" = ("DownEncoderBlock2D",), + up_block_types: "Tuple[str]" = ("UpDecoderBlock2D",), + block_out_channels: "Tuple[int]" = (64,), + layers_per_block: "int" = 1, + act_fn: "str" = "silu", + latent_channels: "int" = 3, + sample_size: "int" = 32, + num_vq_embeddings: "int" = 256, + norm_num_groups: "int" = 32, + vq_embed_dim: "Optional[int]" = None, + scaling_factor: "float" = 0.18215, + ): + super().__init__() + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + ) + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, + vq_embed_dim, + beta=0.25, + remap=None, + sane_index_shape=False, + ) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + + def encode( + self, x: "torch.FloatTensor", return_dict: "bool" = True + ) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + if not return_dict: + return (h,) + return VQEncoderOutput(latents=h) + + def decode( + self, + h: "torch.FloatTensor", + force_not_quantize: "bool" = False, + return_dict: "bool" = True, + ) -> Union[DecoderOutput, torch.FloatTensor]: + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, sample: "torch.FloatTensor", return_dict: "bool" = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: "int", + num_heads: "int", + head_dim: "int", + dropout: "float" = 0.0, + is_decoder: "bool" = False, + bias: "bool" = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: "torch.Tensor", seq_len: "int", bsz: "int"): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: "torch.Tensor", + key_value_states: "Optional[torch.Tensor]" = None, + past_key_value: "Optional[Tuple[torch.Tensor]]" = None, + attention_mask: "Optional[torch.Tensor]" = None, + layer_head_mask: "Optional[torch.Tensor]" = None, + output_attentions: "bool" = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) * self.scaling + if is_cross_attention and past_key_value is not None: + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + if self.is_decoder: + past_key_value = key_states, value_states + proj_shape = bsz * self.num_heads, -1, self.head_dim + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {bsz * self.num_heads, tgt_len, src_len}, but is {attn_weights.size()}" + ) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {bsz, 1, tgt_len, src_len}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise 
ValueError( + f"Head mask for a single layer should be of size {self.num_heads,}, but is {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.bmm(attn_probs, value_states) + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {bsz, self.num_heads, tgt_len, self.head_dim}, but is {attn_output.size()}" + ) + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights_reshaped, past_key_value + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +# ACT2CLS = { +# "gelu": GELUActivation, +# "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), +# "gelu_fast": FastGELUActivation, +# "gelu_new": NewGELUActivation, +# "gelu_python": (GELUActivation, {"use_gelu_python": True}), +# "gelu_pytorch_tanh": PytorchGELUTanh, +# "linear": LinearActivation, +# "mish": MishActivation, +# "quick_gelu": QuickGELUActivation, +# "relu": nn.ReLU, +# "relu6": nn.ReLU6, +# "sigmoid": nn.Sigmoid, +# "silu": SiLUActivation, +# "swish": SiLUActivation, +# "tanh": nn.Tanh, +# } +# ACT2FN = ClassInstantier(ACT2CLS) + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: "LDMBertConfig"): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: "torch.FloatTensor", + attention_mask: "torch.FloatTensor", + layer_head_mask: "torch.FloatTensor", + output_attentions: "Optional[bool]" = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
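+ 
+         Example (a hedged sketch of how the additive mask described above is typically built from a
+         0/1 padding mask; the helper below is illustrative and not part of this file):
+ 
+         ```py
+         import torch
+ 
+         def expand_padding_mask(padding_mask, dtype=torch.float32):
+             # padding_mask: (batch, seq_len) with 1 for real tokens and 0 for padding
+             bsz, seq_len = padding_mask.shape
+             mask = padding_mask[:, None, None, :].to(dtype)   # (batch, 1, 1, src_len)
+             mask = mask.expand(bsz, 1, seq_len, seq_len)      # (batch, 1, tgt_len, src_len)
+             return (1.0 - mask) * torch.finfo(dtype).min      # large negative values on padding
+         ```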
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock( + hid_size, + num_heads, + hid_size, + activation_fn="gelu", + attention_bias=True, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + return hidden_states + + +class GaussianSmoothing(torch.nn.Module): + """ + Arguments: + Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input + using a depthwise convolution. + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the + gaussian kernel. dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__( + self, + channels: "int" = 1, + kernel_size: "int" = 3, + sigma: "float" = 0.5, + dim: "int" = 2, + ): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * dim + if isinstance(sigma, float): + sigma = [sigma] * dim + kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= ( + 1 + / (std * math.sqrt(2 * math.pi)) + * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + ) + kernel = kernel / torch.sum(kernel) + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *([1] * (kernel.dim() - 1))) + self.register_buffer("weight", kernel) + self.groups = channels + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) + ) + + def forward(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. 
+ """ + return self.conv(input, weight=self.weight, groups=self.groups) + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. + """ + + @register_to_config + def __init__(self, embedding_dim: "int" = 768): + super().__init__() + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = embeds * self.std + self.mean + return embeds + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + self.learned_classifier_free_guidance_embeddings = nn.Parameter( + torch.zeros(clip_embeddings_dim) + ) + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear( + clip_embeddings_dim, time_embed_dim + ) + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear( + clip_embeddings_dim, cross_attention_dim + ) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward( + self, + *, + image_embeddings, + prompt_embeds, + text_encoder_hidden_states, + do_classifier_free_guidance, + ): + if do_classifier_free_guidance: + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = ( + self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + ) + classifier_free_guidance_embeddings = ( + classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + ) + image_embeddings = torch.cat( + [classifier_free_guidance_embeddings, image_embeddings], dim=0 + ) + assert image_embeddings.shape[0] == prompt_embeds.shape[0] + batch_size = prompt_embeds.shape[0] + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) + time_projected_image_embeddings = ( + self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + ) + additive_clip_time_embeddings = ( + time_projected_image_embeddings + time_projected_prompt_embeds + ) + clip_extra_context_tokens = self.clip_extra_context_tokens_proj( + image_embeddings + ) + clip_extra_context_tokens = clip_extra_context_tokens.reshape( + batch_size, -1, self.clip_extra_context_tokens + ) + text_encoder_hidden_states = self.encoder_hidden_states_proj( + text_encoder_hidden_states + ) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm( + text_encoder_hidden_states + ) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) + text_encoder_hidden_states = torch.cat( + [clip_extra_context_tokens, text_encoder_hidden_states], dim=2 + ) + return text_encoder_hidden_states, additive_clip_time_embeddings + + +class 
UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.num_heads = in_channels // self.attn_num_head_channels + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + 
non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + attentions.append( + CrossAttention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + """ + UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip + the mid block layer if `None`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. 
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, or `"projection"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
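+ 
+     Example (a hedged, construction-only sketch mirroring the 2D variant; the reduced sizes are
+     illustrative assumptions and the block names must use the Flat variants listed above):
+ 
+     ```py
+     unet_flat = UNetFlatConditionModel(
+         block_out_channels=(64, 128),
+         down_block_types=("CrossAttnDownBlockFlat", "DownBlockFlat"),
+         up_block_types=("UpBlockFlat", "CrossAttnUpBlockFlat"),
+         mid_block_type="UNetMidBlockFlatCrossAttn",
+         cross_attention_dim=64,
+         attention_head_dim=8,
+         layers_per_block=1,
+     )
+     ```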
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: "Optional[int]" = None, + in_channels: "int" = 4, + out_channels: "int" = 4, + center_input_sample: "bool" = False, + flip_sin_to_cos: "bool" = True, + freq_shift: "int" = 0, + down_block_types: "Tuple[str]" = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + mid_block_type: "Optional[str]" = "UNetMidBlockFlatCrossAttn", + up_block_types: "Tuple[str]" = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: "Union[bool, Tuple[bool]]" = False, + block_out_channels: "Tuple[int]" = (320, 640, 1280, 1280), + layers_per_block: "int" = 2, + downsample_padding: "int" = 1, + mid_block_scale_factor: "float" = 1, + act_fn: "str" = "silu", + norm_num_groups: "Optional[int]" = 32, + norm_eps: "float" = 1e-05, + cross_attention_dim: "int" = 1280, + attention_head_dim: "Union[int, Tuple[int]]" = 8, + dual_cross_attention: "bool" = False, + use_linear_projection: "bool" = False, + class_embed_type: "Optional[str]" = None, + num_class_embeds: "Optional[int]" = None, + upcast_attention: "bool" = False, + resnet_time_scale_shift: "str" = "default", + time_embedding_type: "str" = "positional", + timestep_post_act: "Optional[str]" = None, + time_cond_proj_dim: "Optional[int]" = None, + conv_in_kernel: "int" = 3, + conv_out_kernel: "int" = 3, + projection_class_embeddings_input_dim: "Optional[int]" = None, + ): + super().__init__() + self.sample_size = sample_size + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. 
Please make sure to use one of `fourier` or `positional`." + ) + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + self.num_upsamplers = 0 + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) +
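The down-block loop in this hunk threads channels from one block to the next: each block consumes the previous block's output channels, produces `block_out_channels[i]`, and only the final block skips downsampling. A standalone sketch of that wiring with illustrative channel counts:

```python
# Plain-Python sketch of the (in_channels, out_channels, add_downsample) wiring
# computed by the down-block loop; channel counts are illustrative.
block_out_channels = (320, 640, 1280, 1280)
output_channel = block_out_channels[0]
for i, channels in enumerate(block_out_channels):
    input_channel, output_channel = output_channel, channels
    is_final_block = i == len(block_out_channels) - 1
    print(f"block {i}: {input_channel} -> {output_channel}, downsample={not is_final_block}")
# block 0: 320 -> 320, downsample=True
# block 1: 320 -> 640, downsample=True
# block 2: 640 -> 1280, downsample=True
# block 3: 1280 -> 1280, downsample=False
```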
output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + """ + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model, + indexed by its weight name. + """ + processors = {} + + def fn_recursive_add_processors( + name: "str", + module: "torch.nn.Module", + processors: "Dict[str, AttnProcessor]", + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + return processors + + def set_attn_processor( + self, processor: "Union[AttnProcessor, Dict[str, AttnProcessor]]" + ): + """ + Parameters: + processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ ) + + def fn_recursive_attn_processor( + name: "str", module: "torch.nn.Module", processor + ): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_attention_slice(self, slice_size): + """ + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: "torch.nn.Module"): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": + slice_size = [(dim // 2) for dim in sliceable_head_dims] + elif slice_size == "max": + slice_size = num_slicable_layers * [1] + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ ) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + def fn_recursive_set_attention_slice( + module: "torch.nn.Module", slice_size: "List[int]" + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, + (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat), + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + encoder_hidden_states: "torch.Tensor", + class_labels: "Optional[torch.Tensor]" = None, + timestep_cond: "Optional[torch.Tensor]" = None, + attention_mask: "Optional[torch.Tensor]" = None, + cross_attention_kwargs: "Optional[Dict[str, Any]]" = None, + down_block_additional_residuals: "Optional[Tuple[torch.Tensor]]" = None, + mid_block_additional_residual: "Optional[torch.Tensor]" = None, + return_dict: "bool" = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
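One detail worth noting before the body: `timestep` may arrive as a Python scalar or a 0-d tensor, and the code below promotes it to a 1-element tensor and then expands it to the batch size. A standalone sketch of that normalization (shapes are illustrative):

```python
import torch

# Mirrors the timestep handling in forward(): scalars become 1-element tensors,
# which are then expanded to one timestep per batch element.
sample = torch.randn(4, 4, 16, 16)   # (batch, channel, height, width), illustrative
timestep = 10                        # int, float, or tensor are all accepted
timesteps = torch.tensor([timestep], dtype=torch.int64, device=sample.device)
timesteps = timesteps.expand(sample.shape[0])
print(timesteps)                     # tensor([10, 10, 10, 10])
```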
+ """ + default_overall_up_factor = 2 ** self.num_upsamplers + forward_upsample_size = False + upsample_size = None + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + if attention_mask is not None: + attention_mask = (1 - attention_mask) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + timesteps = timestep + if not torch.is_tensor(timesteps): + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + t_emb = t_emb + emb = self.time_embedding(t_emb, timestep_cond) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + if self.conv_norm_out: + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if not return_dict: + return (sample,) + return UNet2DConditionOutput(sample=sample) + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__( + self, + learnable: "bool", + hidden_size: "Optional[int]" = None, + length: "Optional[int]" = None, + ): + super().__init__() + self.learnable = learnable + if self.learnable: + assert ( + hidden_size is not None + ), "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + self.embeddings = torch.nn.Parameter(embeddings) diff --git a/pi/models/unet/__init__.py b/pi/models/unet/__init__.py deleted file mode 100644 index 1d34370..0000000 --- a/pi/models/unet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .unet_2d_condition import UNet2DConditionModel \ No newline at end of file diff --git a/pi/models/unet/attention.py b/pi/models/unet/attention.py deleted file mode 100644 index 1bd0215..0000000 --- a/pi/models/unet/attention.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from typing import Callable, Optional - -from ... import nn -from ... import pi -from ...nn import functional as F - -from .cross_attention import CrossAttention -from .embeddings import CombinedTimestepLabelEmbeddings - - -xformers = None - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. 
Do not use it anymore - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - norm_num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = ( - channels // num_head_channels if num_head_channels is not None else 1 - ) - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm( - num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True - ) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - self._use_memory_efficient_attention_xformers = False - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size * head_size, seq_len, dim // head_size - ) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, - use_memory_efficient_attention_xformers: bool, - attention_op: Optional[Callable] = None, - ): - # if use_memory_efficient_attention_xformers: - # if not is_xformers_available(): - # raise ModuleNotFoundError( - # ( - # "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - # " xformers" - # ), - # name="xformers", - # ) - # elif not pi.cuda.is_available(): - # raise ValueError( - # "pi.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" - # " only available for GPU " - # ) - # else: - # try: - # # Make sure we can run the memory efficient attention - # _ = xformers.ops.memory_efficient_attention( - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # ) - # except Exception as e: - # raise e - # self._use_memory_efficient_attention_xformers = ( - # use_memory_efficient_attention_xformers - # ) - # self._attention_op = attention_op - raise NotImplementedError - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose( - 1, 2 - ) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - query_proj = self.reshape_heads_to_batch_dim(query_proj) - key_proj = self.reshape_heads_to_batch_dim(key_proj) - value_proj = self.reshape_heads_to_batch_dim(value_proj) - - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op - ) - hidden_states = hidden_states.to(query_proj.dtype) - else: - attention_scores = pi.baddbmm( - pi.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = pi.softmax(attention_scores.float(), dim=-1).type( - attention_scores.dtype - ) - hidden_states = pi.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch, channel, height, width - ) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
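The forward pass that follows applies these pieces in a fixed pre-norm, residual order: self-attention, optional cross-attention, then the feed-forward network. The stand-in below is not the deleted implementation, only a compact sketch of that control flow using stock PyTorch modules:

```python
import torch
from torch import nn

# Simplified stand-in for the pre-norm residual pattern described above;
# dimensions and module choices are illustrative, not taken from the patch.
class TinyTransformerBlock(nn.Module):
    def __init__(self, dim=64, heads=4, cross_dim=32):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = nn.MultiheadAttention(
            dim, heads, kdim=cross_dim, vdim=cross_dim, batch_first=True
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, context):
        h = self.norm1(x)
        x = x + self.attn1(h, h, h, need_weights=False)[0]       # self-attention + residual
        h = self.norm2(x)
        x = x + self.attn2(h, context, context, need_weights=False)[0]  # cross-attention + residual
        return x + self.ff(self.norm3(x))                         # feed-forward + residual

x = torch.randn(2, 16, 64)    # (batch, tokens, dim)
ctx = torch.randn(2, 8, 32)   # (batch, context tokens, cross_dim)
print(TinyTransformerBlock()(x, ctx).shape)  # torch.Size([2, 16, 64])
```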
- """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = ( - num_embeds_ada_norm is not None - ) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = ( - num_embeds_ada_norm is not None - ) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # 1. Self-Attn - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None: - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = None - - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - if cross_attention_dim is not None: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - else: - self.norm2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - class_labels=None, - ): - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - # 1. 
Self-Attention - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states - if self.only_cross_attention - else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) - if self.use_ada_layer_norm - else self.norm2(hidden_states) - ) - - # 2. Cross-Attention - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = ( - norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ) - - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. 
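In recent PyTorch releases (1.12 and later) the tanh approximation is available directly through the functional API, which is what this wrapper relies on. A quick comparison of the exact and approximate forms:

```python
import torch
import torch.nn.functional as F

# The exact (erf-based) GELU and its tanh approximation differ only slightly.
x = torch.randn(4, 8)
exact = F.gelu(x)
approx = F.gelu(x, approximate="tanh")
print((exact - approx).abs().max())  # small but nonzero difference
```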
- """ - - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - self.approximate = approximate - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=pi.float32), approximate=self.approximate).to( - dtype=gate.dtype - ) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states - - -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=pi.float32)).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class ApproximateGELU(nn.Module): - """ - The approximate form of Gaussian Error Linear Unit (GELU) - - For more details, see section 2: https://arxiv.org/abs/1606.08415 - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - - def forward(self, x): - x = self.proj(x) - return x * pi.sigmoid(1.702 * x) - - -class AdaLayerNorm(nn.Module): - """ - Norm layer modified to incorporate timestep embeddings. - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) - - def forward(self, x, timestep): - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = pi.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x - - -class AdaLayerNormZero(nn.Module): - """ - Norm layer adaptive layer norm zero (adaLN-Zero). - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - - self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) - - def forward(self, x, timestep, class_labels, hidden_dtype=None): - emb = self.linear( - self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)) - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( - 6, dim=1 - ) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp diff --git a/pi/models/unet/cross_attention.py b/pi/models/unet/cross_attention.py deleted file mode 100644 index dc15f86..0000000 --- a/pi/models/unet/cross_attention.py +++ /dev/null @@ -1,679 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, Optional, Union - -from ... import nn -from ... import pi -from ...nn import functional as F - - -xformers = None - - -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias=False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - processor: Optional["AttnProcessor"] = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = ( - cross_attention_dim if cross_attention_dim is not None else query_dim - ) - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - - self.scale = dim_head ** -0.5 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm( - num_channels=inner_dim, - num_groups=norm_num_groups, - eps=1e-5, - affine=True, - ) - else: - self.group_norm = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - self.to_out.append(nn.Dropout(dropout)) - - # set attention processor - processor = processor if processor is not None else CrossAttnProcessor() - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, - use_memory_efficient_attention_xformers: bool, - attention_op: Optional[Callable] = None, - ): - if use_memory_efficient_attention_xformers: - # if self.added_kv_proj_dim is not None: - # # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # # which uses this type of cross attention ONLY because the attention mask of format - # # [0, ..., -10.000, ..., 0, ...,] is not supported - # raise NotImplementedError( - # "Memory efficient attention with `xformers` is 
currently not supported when" - # " `self.added_kv_proj_dim` is defined." - # ) - # elif not is_xformers_available(): - # raise ModuleNotFoundError( - # ( - # "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - # " xformers" - # ), - # name="xformers", - # ) - # elif not pi.cuda.is_available(): - # raise ValueError( - # "pi.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - # " only available for GPU " - # ) - # else: - # try: - # # Make sure we can run the memory efficient attention - # _ = xformers.ops.memory_efficient_attention( - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # ) - # except Exception as e: - # raise e - # - # processor = XFormersCrossAttnProcessor(attention_op=attention_op) - raise NotImplementedError - else: - processor = CrossAttnProcessor() - - self.set_processor(processor) - - def set_attention_slice(self, slice_size): - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." - ) - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = CrossAttnAddedKVProcessor() - else: - processor = CrossAttnProcessor() - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor"): - self.processor = processor - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - **cross_attention_kwargs, - ): - # The `CrossAttention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def head_to_batch_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size * head_size, seq_len, dim // head_size - ) - return tensor - - def get_attention_scores(self, query, key, attention_mask=None): - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = pi.baddbmm( - pi.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask(self, attention_mask, target_length): - head_size = 
self.heads - if attention_mask is None: - return attention_mask - - if attention_mask.shape[-1] != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. - padding_shape = ( - attention_mask.shape[0], - attention_mask.shape[1], - target_length, - ) - padding = pi.zeros(padding_shape, device=attention_mask.device) - attention_mask = pi.concat([attention_mask, padding], dim=2) - else: - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - return attention_mask - - -class CrossAttnProcessor: - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = pi.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError( - f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" - ) - - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) - self.scale = 1.0 - - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - return up_hidden_states.to(orig_dtype) - - -class LoRACrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): - super().__init__() - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size) - self.to_k_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_v_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( - encoder_hidden_states 
- ) - - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = pi.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( - hidden_states - ) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class CrossAttnAddedKVProcessor: - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - residual = hidden_states - hidden_states = hidden_states.view( - hidden_states.shape[0], hidden_states.shape[1], -1 - ).transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.head_to_batch_dim( - encoder_hidden_states_key_proj - ) - encoder_hidden_states_value_proj = attn.head_to_batch_dim( - encoder_hidden_states_value_proj - ) - - key = pi.concat([encoder_hidden_states_key_proj, key], dim=1) - value = pi.concat([encoder_hidden_states_value_proj, value], dim=1) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = pi.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - - return hidden_states - - -class XFormersCrossAttnProcessor: - def __init__(self, attention_op: Optional[Callable] = None): - self.attention_op = attention_op - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op - ) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class LoRAXFormersCrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim, rank=4): - super().__init__() - - self.to_q_lora = 
LoRALinearLayer(hidden_size, hidden_size) - self.to_k_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_v_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query).contiguous() - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( - encoder_hidden_states - ) - - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( - hidden_states - ) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SlicedAttnProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention = query.shape[0] - hidden_states = pi.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - - for i in range(hidden_states.shape[0] // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) - - attn_slice = attn.get_attention_scores( - query_slice, key_slice, attn_mask_slice - ) - - attn_slice = pi.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SlicedAttnAddedKVProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - residual = hidden_states - hidden_states = hidden_states.view( - hidden_states.shape[0], hidden_states.shape[1], -1 - ).transpose(1, 2) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - - batch_size, sequence_length, _ = hidden_states.shape 
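The `SlicedAttnProcessor` just above trades speed for memory by materializing only one chunk of the attention matrix at a time along the `(batch * heads)` dimension. A self-contained sketch of the idea (not the deleted processor itself, and without the masking and head-reshaping details):

```python
import torch

def sliced_attention(q, k, v, slice_size):
    # q, k, v: (batch * heads, tokens, head_dim); only one slice's score
    # matrix exists in memory at any point in the loop.
    out = torch.zeros_like(q)
    scale = q.shape[-1] ** -0.5
    for start in range(0, q.shape[0], slice_size):
        qs = q[start:start + slice_size]
        ks = k[start:start + slice_size]
        vs = v[start:start + slice_size]
        scores = torch.bmm(qs, ks.transpose(-1, -2)) * scale
        out[start:start + slice_size] = torch.bmm(scores.softmax(dim=-1), vs)
    return out

q = k = v = torch.randn(8, 16, 40)
print(sliced_attention(q, k, v, slice_size=2).shape)  # torch.Size([8, 16, 40])
```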
- - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.head_to_batch_dim( - encoder_hidden_states_key_proj - ) - encoder_hidden_states_value_proj = attn.head_to_batch_dim( - encoder_hidden_states_value_proj - ) - - key = pi.concat([encoder_hidden_states_key_proj, key], dim=1) - value = pi.concat([encoder_hidden_states_value_proj, value], dim=1) - - batch_size_attention = query.shape[0] - hidden_states = pi.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - - for i in range(hidden_states.shape[0] // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) - - attn_slice = attn.get_attention_scores( - query_slice, key_slice, attn_mask_slice - ) - - attn_slice = pi.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - - return hidden_states - - -AttnProcessor = Union[ - CrossAttnProcessor, - XFormersCrossAttnProcessor, - SlicedAttnProcessor, - CrossAttnAddedKVProcessor, - SlicedAttnAddedKVProcessor, -] diff --git a/pi/models/unet/embeddings.py b/pi/models/unet/embeddings.py deleted file mode 100644 index de0cbe8..0000000 --- a/pi/models/unet/embeddings.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import numpy as np -from ... import nn -from ... import pi -from ...nn import functional as F - - -def get_timestep_embedding( - timesteps: pi.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. 
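A compact standalone version of the sinusoidal embedding described here, ignoring the `downscale_freq_shift`, `scale`, and `flip_sin_to_cos` options that the full function handles:

```python
import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    # Half the channels carry sin, half carry cos, over a geometric frequency range.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 500]), dim=128)
print(emb.shape)  # torch.Size([3, 128])
```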
:return: an [N x dim] Tensor of positional embeddings. - """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * pi.arange( - start=0, end=half_dim, dtype=pi.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = pi.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = pi.cat([pi.sin(emb), pi.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = pi.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = pi.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate( - [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 - ) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000 ** omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - ): - super().__init__() - - num_patches = (height // patch_size) * (width // patch_size) - self.flatten = flatten - self.layer_norm = layer_norm - - self.proj = nn.Conv2d( - in_channels, - embed_dim, - kernel_size=(patch_size, patch_size), - stride=patch_size, - bias=bias, - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** 0.5)) - self.register_buffer( - "pos_embed", pi.from_numpy(pos_embed).float().unsqueeze(0), persistent=False - ) - - def forward(self, latent): - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC - if self.layer_norm: - latent = self.norm(latent) - 
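# Self-contained sketch of the sinusoidal embedding computed by
# get_timestep_embedding above, in a simplified form (downscale_freq_shift,
# scaling, and sin/cos flipping omitted): geometrically spaced frequencies,
# then concatenated sin and cos of timestep * frequency.
import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    angles = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), dim=32)   # shape (3, 32)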
return latent + self.pos_embed - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - self.act = None - if act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "mish": - self.act = nn.Mish() - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - - def forward(self, sample): - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__( - self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float - ): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__( - self, - embedding_size: int = 256, - scale: float = 1.0, - set_W_to_weight=True, - log=True, - flip_sin_to_cos=False, - ): - super().__init__() - self.weight = nn.Parameter( - pi.randn(embedding_size) * scale, requires_grad=False - ) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - - if set_W_to_weight: - # to delete later - self.W = nn.Parameter(pi.randn(embedding_size) * scale, requires_grad=False) - - self.weight = self.W - - def forward(self, x): - if self.log: - x = pi.log(x) - - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - - if self.flip_sin_to_cos: - out = pi.cat([pi.cos(x_proj), pi.sin(x_proj)], dim=-1) - else: - out = pi.cat([pi.sin(x_proj), pi.cos(x_proj)], dim=-1) - return out - - -class ImagePositionalEmbeddings(nn.Module): - """ - Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the - height and width of the latent space. - - For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 - - For VQ-diffusion: - - Output vector embeddings are used as input for the transformer. - - Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. - - Args: - num_embed (`int`): - Number of embeddings for the latent pixels embeddings. - height (`int`): - Height of the latent image i.e. the number of height embeddings. - width (`int`): - Width of the latent image i.e. the number of width embeddings. - embed_dim (`int`): - Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 
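# Sketch of how Timesteps and TimestepEmbedding above are usually composed: an
# integer timestep becomes sinusoidal features, which a Linear-SiLU-Linear MLP
# maps to the UNet's time-embedding width. The widths (320 -> 1280) are only
# illustrative, not taken from this file.
import torch
import torch.nn as nn

time_mlp = nn.Sequential(nn.Linear(320, 1280), nn.SiLU(), nn.Linear(1280, 1280))
t_feats = torch.randn(4, 320)    # stand-in for the Timesteps sinusoidal projection
temb = time_mlp(t_feats)         # (4, 1280), later injected into each ResnetBlock2D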
- """ - - def __init__( - self, - num_embed: int, - height: int, - width: int, - embed_dim: int, - ): - super().__init__() - - self.height = height - self.width = width - self.num_embed = num_embed - self.embed_dim = embed_dim - - self.emb = nn.Embedding(self.num_embed, embed_dim) - self.height_emb = nn.Embedding(self.height, embed_dim) - self.width_emb = nn.Embedding(self.width, embed_dim) - - def forward(self, index): - emb = self.emb(index) - - height_emb = self.height_emb( - pi.arange(self.height, device=index.device).view(1, self.height) - ) - - # 1 x H x D -> 1 x H x 1 x D - height_emb = height_emb.unsqueeze(2) - - width_emb = self.width_emb( - pi.arange(self.width, device=index.device).view(1, self.width) - ) - - # 1 x W x D -> 1 x 1 x W x D - width_emb = width_emb.unsqueeze(1) - - pos_emb = height_emb + width_emb - - # 1 x H x W x D -> 1 x L xD - pos_emb = pos_emb.view(1, self.height * self.width, -1) - - emb = emb + pos_emb[:, : emb.shape[1], :] - - return emb - - -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding( - num_classes + use_cfg_embedding, hidden_size - ) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. - """ - if force_drop_ids is None: - drop_ids = ( - pi.rand(labels.shape[0], device=labels.device) < self.dropout_prob - ) - else: - drop_ids = pi.tensor(force_drop_ids == 1) - labels = pi.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class CombinedTimestepLabelEmbeddings(nn.Module): - def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): - super().__init__() - - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1 - ) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, time_embed_dim=embedding_dim - ) - self.class_embedder = LabelEmbedding( - num_classes, embedding_dim, class_dropout_prob - ) - - def forward(self, timestep, class_labels, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.to(dtype=hidden_dtype) - ) # (N, D) - - class_labels = self.class_embedder(class_labels) # (N, D) - - conditioning = timesteps_emb + class_labels # (N, D) - - return conditioning diff --git a/pi/models/unet/resnet.py b/pi/models/unet/resnet.py deleted file mode 100644 index ec95523..0000000 --- a/pi/models/unet/resnet.py +++ /dev/null @@ -1,752 +0,0 @@ -from functools import partial - -from ... import nn -from ... import pi -from ...nn import functional as F - - -class Upsample1D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. 
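# Standalone sketch of the classifier-free-guidance dropout performed by
# LabelEmbedding.token_drop above: with probability dropout_prob a label is
# swapped for the extra "null" class (index num_classes), which is why the
# embedding table is sized num_classes + 1.
import torch

num_classes, dropout_prob = 10, 0.1
labels = torch.randint(0, num_classes, (8,))
drop = torch.rand(labels.shape[0]) < dropout_prob
labels = torch.where(drop, torch.full_like(labels, num_classes), labels)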
- use_conv_transpose: - out_channels: - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(x) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - if self.use_conv: - x = self.conv(x) - - return x - - -class Downsample1D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: - """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = nn.Conv1d( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.conv(x) - - -class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - return self.conv(hidden_states) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == pi.bfloat16: - hidden_states = hidden_states.to(pi.float32) - - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate( - hidden_states, scale_factor=2.0, mode="nearest" - ) - else: - hidden_states = F.interpolate( - hidden_states, size=output_size, mode="nearest" - ) - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == pi.bfloat16: - hidden_states = hidden_states.to(dtype) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.Conv2d_0(hidden_states) - - return hidden_states - - -class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: - """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - conv = nn.Conv2d( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class FirUpsample2D(nn.Module): - def __init__( - self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d( - channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `upsample_2d()` followed by `Conv2d()`. - - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
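# Minimal sketch of the default Upsample2D / Downsample2D paths above
# (use_conv=True, no transposed convolution): 2x nearest-neighbour interpolation
# plus a 3x3 conv on the way up, and a stride-2 3x3 conv on the way down.
# Channel counts are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 64, 16, 16)
up_conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)
down_conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)

up = up_conv(F.interpolate(x, scale_factor=2.0, mode="nearest"))   # (1, 64, 32, 32)
down = down_conv(x)                                                # (1, 64, 8, 8)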
- - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. - """ - - assert isinstance(factor, int) and factor >= 1 - - # Setup filter kernel. - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * (gain * (factor ** 2)) - - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - - pad_value = (kernel.shape[0] - factor) - (convW - 1) - - stride = (factor, factor) - # Determine data dimensions. - output_shape = ( - (hidden_states.shape[2] - 1) * factor + convH, - (hidden_states.shape[3] - 1) * factor + convW, - ) - output_padding = ( - output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = hidden_states.shape[1] // inC - - # Transpose weights. - weight = pi.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = pi.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) - weight = pi.reshape(weight, (num_groups * inC, -1, convH, convW)) - - inverse_conv = F.conv_transpose2d( - hidden_states, - weight, - stride=stride, - output_padding=output_padding, - padding=0, - ) - - output = upfirdn2d_native( - inverse_conv, - pi.tensor(kernel, device=inverse_conv.device), - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), - ) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - pi.tensor(kernel, device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - - return output - - def forward(self, hidden_states): - if self.use_conv: - height = self._upsample_2d( - hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel - ) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return height - - -class FirDownsample2D(nn.Module): - def __init__( - self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d( - channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `Conv2d()` followed by `downsample_2d()`. - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and - same datatype as `x`. - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * gain - - if self.use_conv: - _, _, convH, convW = weight.shape - pad_value = (kernel.shape[0] - factor) + (convW - 1) - stride_value = [factor, factor] - upfirdn_input = upfirdn2d_native( - hidden_states, - pi.tensor(kernel, device=hidden_states.device), - pad=((pad_value + 1) // 2, pad_value // 2), - ) - output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - pi.tensor(kernel, device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - - return output - - def forward(self, hidden_states): - if self.use_conv: - downsample_input = self._downsample_2d( - hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel - ) - hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - hidden_states = self._downsample_2d( - hidden_states, kernel=self.fir_kernel, factor=2 - ) - - return hidden_states - - -class ResnetBlock2D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - kernel=None, - output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = pi.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - - self.conv1 = pi.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if temb_channels is not None: - if self.time_embedding_norm == "default": - time_emb_proj_out_channels = out_channels - elif self.time_embedding_norm == "scale_shift": - time_emb_proj_out_channels = out_channels * 2 - else: - raise ValueError( - f"unknown time_embedding_norm : {self.time_embedding_norm} " - ) - - self.time_emb_proj = pi.nn.Linear(temb_channels, time_emb_proj_out_channels) - else: - self.time_emb_proj = None - - self.norm2 = pi.nn.GroupNorm( - num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True - ) - self.dropout = pi.nn.Dropout(dropout) - self.conv2 = pi.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.upsample = self.downsample = None - if self.up: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.upsample = partial(F.interpolate, scale_factor=2.0, 
mode="nearest") - else: - self.upsample = Upsample2D(in_channels, use_conv=False) - elif self.down: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) - else: - self.downsample = Downsample2D( - in_channels, use_conv=False, padding=1, name="op" - ) - - self.use_in_shortcut = ( - self.in_channels != self.out_channels - if use_in_shortcut is None - else use_in_shortcut - ) - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = pi.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = pi.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class Mish(pi.nn.Module): - def forward(self, hidden_states): - return hidden_states * pi.tanh(pi.nn.functional.softplus(hidden_states)) - - -# unet_rl.py -def rearrange_dims(tensor): - if len(tensor.shape) == 2: - return tensor[:, :, None] - if len(tensor.shape) == 3: - return tensor[:, :, None, :] - elif len(tensor.shape) == 4: - return tensor[:, :, 0, :] - else: - raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.conv1d = nn.Conv1d( - inp_channels, out_channels, kernel_size, padding=kernel_size // 2 - ) - self.group_norm = nn.GroupNorm(n_groups, out_channels) - self.mish = nn.Mish() - - def forward(self, x): - x = self.conv1d(x) - x = rearrange_dims(x) - x = self.group_norm(x) - x = rearrange_dims(x) - x = self.mish(x) - return x - - -# unet_rl.py -class ResidualTemporalBlock1D(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): - super().__init__() - self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) - self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) - - self.time_emb_act = nn.Mish() - self.time_emb = nn.Linear(embed_dim, out_channels) - - self.residual_conv = ( - nn.Conv1d(inp_channels, 
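# Condensed sketch of ResnetBlock2D.forward above for the default
# time_embedding_norm and no up/downsampling: GroupNorm -> SiLU -> conv, add the
# projected time embedding, GroupNorm -> SiLU -> dropout -> conv, then the
# residual add (through a 1x1 shortcut only when channel counts differ).
import torch
import torch.nn as nn
import torch.nn.functional as F

x, temb = torch.randn(2, 32, 16, 16), torch.randn(2, 512)
norm1, conv1 = nn.GroupNorm(32, 32), nn.Conv2d(32, 32, 3, padding=1)
norm2, conv2 = nn.GroupNorm(32, 32), nn.Conv2d(32, 32, 3, padding=1)
time_proj = nn.Linear(512, 32)

h = conv1(F.silu(norm1(x)))
h = h + time_proj(F.silu(temb))[:, :, None, None]
h = conv2(F.dropout(F.silu(norm2(h)), p=0.0))
out = x + h              # divided by output_scale_factor, which defaults to 1.0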
out_channels, 1) - if inp_channels != out_channels - else nn.Identity() - ) - - def forward(self, x, t): - """ - Args: - x : [ batch_size x inp_channels x horizon ] - t : [ batch_size x embed_dim ] - - returns: - out : [ batch_size x out_channels x horizon ] - """ - t = self.time_emb_act(t) - t = self.time_emb(t) - out = self.conv_in(x) + rearrange_dims(t) - out = self.conv_out(out) - return out + self.residual_conv(x) - - -def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): - r"""Upsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is - a: multiple of the upsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * (gain * (factor ** 2)) - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - -def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): - r"""Downsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * gain - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - return output - - -def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - - _, channel, in_h, in_w = tensor.shape - tensor = tensor.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = tensor.shape - kernel_h, kernel_w = kernel.shape - - out = tensor.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) - out = out.to(tensor.device) # Move back to mps if necessary - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - ) - w = pi.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) diff --git a/pi/models/unet/transformer_2d.py b/pi/models/unet/transformer_2d.py deleted file mode 100644 index 5376875..0000000 --- a/pi/models/unet/transformer_2d.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional - -from ... import nn -from ... import pi -from ...nn import functional as F - -from .attention import BasicTransformerBlock -from .embeddings import PatchEmbed, ImagePositionalEmbeddings - - -@dataclass -class Transformer2DModelOutput: - """ - Args: - sample (`pi.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions - for the unnoised latent pixels. - """ - - sample: pi.FloatTensor - - -class Transformer2DModel(nn.Module): - """ - Transformer model for image-like data. 
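# Standalone sketch of the FIR filtering behind upsample_2d / downsample_2d /
# upfirdn2d_native above: the 1-D kernel (1, 3, 3, 1) becomes a normalized
# separable 2-D filter applied per channel before striding by `factor`. This
# illustrates the idea only; it does not reproduce upfirdn2d_native's exact
# padding arithmetic.
import torch
import torch.nn.functional as F

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k2d = torch.outer(k, k)
k2d = k2d / k2d.sum()                                   # normalized 4x4 FIR filter

x = torch.randn(1, 8, 32, 32)
w = k2d[None, None].expand(8, 1, 4, 4).contiguous()     # one copy per channel
y = F.conv2d(F.pad(x, (1, 2, 1, 2)), w, stride=2, groups=8)   # (1, 8, 16, 16)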
Takes either discrete (classes of vector embeddings) or continuous (actual - embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. - - Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised - image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # 1. 
Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif ( - not self.is_input_continuous - and not self.is_input_vectorized - and not self.is_input_patches - ): - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = pi.nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=in_channels, - eps=1e-6, - affine=True, - ) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - elif self.is_input_vectorized: - assert ( - sample_size is not None - ), "Transformer2DModel over discrete input must provide sample_size" - assert ( - num_vector_embeds is not None - ), "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, - embed_dim=inner_dim, - height=self.height, - width=self.width, - ) - elif self.is_input_patches: - assert ( - sample_size is not None - ), "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - - # 3. 
Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continous projections - if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) - else: - self.proj_out = nn.Conv2d( - inner_dim, in_channels, kernel_size=1, stride=1, padding=0 - ) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear( - inner_dim, patch_size * patch_size * self.out_channels - ) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - Args: - hidden_states ( When discrete, `pi.LongTensor` of shape `(batch size, num latent pixels)`. - When continous, `pi.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states - encoder_hidden_states ( `pi.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `pi.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. - class_labels ( `pi.LongTensor` of shape `(batch size, num classes)`, *optional*): - Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels - conditioning. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( - batch, height * width, inner_dim - ) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( - batch, height * width, inner_dim - ) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - - # 2. 
Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = ( - self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - ) - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=( - -1, - height, - width, - self.patch_size, - self.patch_size, - self.out_channels, - ) - ) - hidden_states = pi.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=( - -1, - self.out_channels, - height * self.patch_size, - width * self.patch_size, - ) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/pi/models/unet/unet_2d_blocks.py b/pi/models/unet/unet_2d_blocks.py deleted file mode 100644 index 75b51b5..0000000 --- a/pi/models/unet/unet_2d_blocks.py +++ /dev/null @@ -1,2312 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -from diffusers import Transformer2DModel - -from .attention import AttentionBlock -from .resnet import ResnetBlock2D -from ... import nn, Tensor -from ... 
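# Sketch of the continuous-input path of Transformer2DModel.forward above: the
# (B, C, H, W) feature map is flattened to (B, H*W, C) tokens, run through the
# transformer blocks with the encoder hidden states as cross-attention context,
# projected back, reshaped, and added to the residual. nn.Identity stands in
# for BasicTransformerBlock and the projections.
import torch
import torch.nn as nn

b, c, h, w = 2, 320, 32, 32
hidden = torch.randn(b, c, h, w)
residual = hidden

tokens = hidden.permute(0, 2, 3, 1).reshape(b, h * w, c)    # (B, H*W, C)
tokens = nn.Identity()(tokens)                              # stand-in for the block stack
out = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2) + residual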
import pi - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - down_block_type = ( - down_block_type[7:] - if down_block_type.startswith("UNetRes") - else down_block_type - ) - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "ResnetDownsampleBlock2D": - return ResnetDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownBlock2D": - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnDownBlock2D" - ) - return CrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SimpleCrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" - ) - return SimpleCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == 
"AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownEncoderBlock2D": - return AttnDownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - up_block_type = ( - up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - ) - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "ResnetUpsampleBlock2D": - return ResnetUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnUpBlock2D" - ) - return CrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SimpleCrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" - ) - return 
SimpleCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpBlock2D": - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpDecoderBlock2D": - return AttnUpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{up_block_type} does not exist.") - - -class UNetMidBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - 
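# The get_down_block / get_up_block factories above are plain string
# dispatchers: a block-type name from the UNet config picks a concrete module
# class and forwards the shared keyword arguments. A reduced sketch of that
# pattern, with nn.Identity standing in for the real block classes:
import torch.nn as nn

BLOCKS = {"DownBlock2D": nn.Identity, "CrossAttnDownBlock2D": nn.Identity}

def make_down_block(name, **kwargs):
    name = name[7:] if name.startswith("UNetRes") else name   # same prefix stripping as above
    if name not in BLOCKS:
        raise ValueError(f"{name} does not exist.")
    return BLOCKS[name]()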
rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class 
UNetMidBlock2DSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - ): - super().__init__() - - self.has_cross_attention = True - - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - self.num_heads = in_channels // self.attn_num_head_channels - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - CrossAttention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - # resnet - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class AttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - 
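# --- Illustrative sketch, not part of the deleted file above ---
# All three UNetMidBlock2D* variants share one layout: a leading resnet followed
# by num_layers (attention, resnet) pairs, which the forward pass interleaves.
# Minimal runnable version with nn.Identity stand-ins for the real layers:
import torch
import torch.nn as nn

resnets = nn.ModuleList([nn.Identity(), nn.Identity()])  # resnets[0] always runs first
attentions = nn.ModuleList([nn.Identity()])              # one attention per extra resnet


def mid_block_forward(hidden_states):
    hidden_states = resnets[0](hidden_states)
    for attn, resnet in zip(attentions, resnets[1:]):
        hidden_states = attn(hidden_states)
        hidden_states = resnet(hidden_states)
    return hidden_states


out = mid_block_forward(torch.randn(1, 4, 8, 8))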
norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - # TODO(Patrick, William) - attention mask is not used - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - 
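# --- Illustrative sketch, not part of the deleted file above ---
# Every *DownBlock2D forward above returns both the final feature map and a tuple
# of per-layer outputs; the UNet later hands those back to the matching up block
# as skip connections. The collection pattern, reduced to its core:
import torch


def run_down_block(hidden_states, layers):
    output_states = ()
    for layer in layers:
        hidden_states = layer(hidden_states)
        output_states += (hidden_states,)  # save one skip tensor per layer
    return hidden_states, output_states


h, skips = run_down_block(torch.randn(1, 8, 16, 16),
                          [torch.nn.Identity(), torch.nn.Identity()])
assert len(skips) == 2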
hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - 
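# --- Illustrative sketch, not part of the deleted file above ---
# During training with gradient_checkpointing enabled, the blocks above wrap each
# resnet/attention call in a checkpoint so activations are recomputed on the
# backward pass instead of stored. The create_custom_forward closure exists
# because checkpoint() only forwards positional arguments. The deleted code calls
# pi.utils.checkpoint.checkpoint; the torch equivalent is shown here (assumption:
# the pi API mirrors torch's).
import torch
import torch.utils.checkpoint


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
y.sum().backward()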
Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnDownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnSkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - 
groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)] - ) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1) - ) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class SkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)] - ) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1) - ) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class ResnetDownsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = 
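# --- Illustrative sketch, not part of the deleted file above ---
# In the Skip*DownBlock2D variants, the image-space skip_sample is downsampled in
# lockstep with the feature map and folded back in through a 1x1 conv:
# hidden_states = skip_conv(skip_sample) + hidden_states. Shape-level sketch with
# an average pool standing in for FirDownsample2D (a simplification):
import torch
import torch.nn as nn
import torch.nn.functional as F

skip_conv = nn.Conv2d(3, 8, kernel_size=1)   # 3 image channels -> block channels
hidden_states = torch.randn(1, 8, 16, 16)
skip_sample = torch.randn(1, 3, 32, 32)

skip_sample = F.avg_pool2d(skip_sample, 2)   # stand-in for the FIR downsampler
hidden_states = skip_conv(skip_sample) + hidden_states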
"default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class SimpleCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - - self.has_cross_attention = True - - resnets = [] - attentions = [] - - self.attn_num_head_channels = attn_num_head_channels - self.num_heads = out_channels // self.attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - 
dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - output_states = () - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - hidden_states = resnet(hidden_states, temb) - - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class AttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - 
output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - # TODO(Patrick, William) - attention mask is not used - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 
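# --- Illustrative sketch, not part of the deleted file above ---
# Every *UpBlock2D forward above consumes the residuals saved by the matching down
# block in reverse order: pop the newest skip tensor, concatenate it with the
# current features along the channel dimension, then run the resnet. Reduced to
# its core (stand-in 1x1 convs play the role of the resnets):
import torch


def run_up_block(hidden_states, res_hidden_states_tuple, layers):
    for layer in layers:
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
        hidden_states = layer(hidden_states)
    return hidden_states


h = torch.randn(1, 4, 8, 8)
skips = (torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8))
out = run_up_block(h, skips, [torch.nn.Conv2d(8, 4, 1), torch.nn.Conv2d(8, 4, 1)])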
32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnUpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - input_channels = 
in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnSkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - upsample_padding=1, - add_upsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(resnet_in_channels + res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d( - out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - self.skip_norm = pi.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), - num_channels=out_channels, - eps=resnet_eps, - affine=True, - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None - ): - for resnet in self.resnets: - # pop res 
hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - hidden_states = self.attentions[0](hidden_states) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return hidden_states, skip_sample - - -class SkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, - upsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min((resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d( - out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - self.skip_norm = pi.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), - num_channels=out_channels, - eps=resnet_eps, - affine=True, - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None - ): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return 
hidden_states, skip_sample - - -class ResnetUpsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states - - -class SimpleCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - self.num_heads = out_channels // self.attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - 
output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states diff --git a/pi/models/unet/unet_2d_condition.py b/pi/models/unet/unet_2d_condition.py deleted file mode 100644 index e5df4d1..0000000 --- a/pi/models/unet/unet_2d_condition.py +++ /dev/null @@ -1,517 +0,0 @@ -import logging -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union, Dict, List, Any - -from diffusers.models.unet_2d_blocks import get_down_block - -from .cross_attention import AttnProcessor -from .embeddings import Timesteps, TimestepEmbedding -from .unet_2d_blocks import ( - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - get_up_block, - CrossAttnDownBlock2D, - DownBlock2D, - CrossAttnUpBlock2D, - UpBlock2D, -) -from ... import nn, Tensor -from ... import pi - - -logger = logging.getLogger(__name__) - - -@dataclass -class UNet2DConditionOutput: - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. 
- """ - - sample: pi.FloatTensor - - -class UNet2DConditionModel: - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: str = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) - ) - - # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - 
output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=3, padding=1 - ) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors( - name: str, module: pi.nn.Module, processors: Dict[str, AttnProcessor] - ): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor( - self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]] - ): - r""" - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: pi.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. 
- """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: pi.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - - num_slicable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_slicable_layers * [1] - - slice_size = ( - num_slicable_layers * [slice_size] - if not isinstance(slice_size, list) - else slice_size - ) - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice( - module: pi.nn.Module, slice_size: List[int] - ): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance( - module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D) - ): - module.gradient_checkpointing = value - - def forward( - self, - sample: pi.FloatTensor, - timestep: Union[pi.Tensor, float, int], - encoder_hidden_states: pi.Tensor, - class_labels: Optional[pi.Tensor] = None, - attention_mask: Optional[pi.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`pi.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`pi.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`pi.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. 
- default_overall_up_factor = 2 ** self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not pi.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = pi.float32 if is_mps else pi.float64 - else: - dtype = pi.int32 if is_mps else pi.int64 - timesteps = pi.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError( - "class_labels should be provided when num_class_embeds > 0" - ) - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if ( - hasattr(downsample_block, "has_cross_attention") - and downsample_block.has_cross_attention - ): - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - # 5. 
up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if ( - hasattr(upsample_block, "has_cross_attention") - and upsample_block.has_cross_attention - ): - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - ) - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) diff --git a/pyproject.toml b/pyproject.toml index 37876fc..314893f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,12 @@ requires = ["setuptools>=42", "ninja"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +log_cli = false +log_cli_level = "DEBUG" +log_cli_format = "[%(filename)s:%(funcName)s:%(lineno)d] %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" + [tool.cibuildwheel] before-build = "pip install -r build-requirements.txt -r requirements.txt -v" # HOLY FUCK From 1a7c9c6b8d33af4aef82ede1a7fdf839f002906d Mon Sep 17 00:00:00 2001 From: max Date: Thu, 16 Mar 2023 00:17:01 -0500 Subject: [PATCH 2/2] use lazy import --- examples/unet.py | 115 +- pi/lazy_importer/__init__.py | 0 pi/lazy_importer/lazy_imports.py | 377 + pi/lazy_importer/run_lazy_imports.py | 76 + pi/models/unet.py | 11078 ------------------------- pi/models/unet/hand_imports.py | 21 + pi/models/unet/linearized.py | 5014 +++++++++++ pi/models/unet/prologue.py | 10 + requirements.txt | 1 + 9 files changed, 5573 insertions(+), 11119 deletions(-) create mode 100644 pi/lazy_importer/__init__.py create mode 100644 pi/lazy_importer/lazy_imports.py create mode 100644 pi/lazy_importer/run_lazy_imports.py delete mode 100644 pi/models/unet.py create mode 100644 pi/models/unet/hand_imports.py create mode 100644 pi/models/unet/linearized.py create mode 100644 pi/models/unet/prologue.py diff --git a/examples/unet.py b/examples/unet.py index 268fef9..5d45508 100644 --- a/examples/unet.py +++ b/examples/unet.py @@ -1,59 +1,92 @@ -import torch +import inspect +import re + import numpy as np +import torch -from pi.models.unet import UNet2DConditionModel -import torch_mlir - -unet = UNet2DConditionModel( - **{ - "block_out_channels": (32, 64), - "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), - "cross_attention_dim": 32, - "attention_head_dim": 8, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 2, - "sample_size": 32, - } -) -unet.eval() - -batch_size = 4 -num_channels = 4 -sizes = (32, 32) +from pi.lazy_importer.run_lazy_imports import do_package_imports, do_hand_imports +from pi.lazy_importer import lazy_imports -def floats_tensor(shape, scale=1.0, rng=None, name=None): +# + +def floats_tensor(shape, scale=1.0, rng=None, name=None): 
total_dims = 1 for dim in shape: total_dims *= dim - values = [] for _ in range(total_dims): values.append(np.random.random() * scale) - return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() -noise = floats_tensor((batch_size, num_channels) + sizes) -time_step = torch.tensor([10]) -encoder_hidden_states = floats_tensor((batch_size, 4, 32)) +def run( + CTor, + down_block_types=("CrossAttnDownBlock2D", "ResnetDownsampleBlock2D"), + up_block_types=("UpBlock2D", "ResnetUpsampleBlock2D"), +): + unet = CTor( + **{ + "block_out_channels": (32, 64), + "down_block_types": down_block_types, + "up_block_types": up_block_types, + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + ) + unet.eval() + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes) + time_step = torch.tensor([10]) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)) + output = unet(noise, time_step, encoder_hidden_states) + + +def make_linearized(): + def filter(ret): + try: + MODULE_TARGET = lambda x: re.match( + r"(huggingface|torch|diffusers)", inspect.getmodule(x).__package__ + ) + return MODULE_TARGET(ret) + except: + return None + + lazy_imports.MODULE_TARGET = filter + + def _inner(): + + from diffusers import UNet2DConditionModel + + run( + UNet2DConditionModel, + down_block_types=("CrossAttnDownBlock2D", "ResnetDownsampleBlock2D"), + up_block_types=("UpBlock2D", "ResnetUpsampleBlock2D"), + ) + run( + UNet2DConditionModel, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + ) + + prefix = "from pi.models.unet.prologue import CONFIG_NAME, LORA_WEIGHT_NAME" + name = "unet_linearized" + do_package_imports(_inner, prefix, name) + -output = unet(noise, time_step, encoder_hidden_states) -print(output) +def run_linearized(): + from pi.models.unet import linearized -traced = torch.jit.trace(unet, (noise, time_step, encoder_hidden_states), strict=False) -frozen = torch.jit.freeze(traced) -print(frozen.graph) + run(linearized.UNet2DConditionModel) -module = torch_mlir.compile( - frozen, - (noise, time_step, encoder_hidden_states), - use_tracing=True, - output_type=torch_mlir.OutputType.RAW, -) -with open("unet.mlir", "w") as f: - f.write(str(module)) +if __name__ == "__main__": + make_linearized() diff --git a/pi/lazy_importer/__init__.py b/pi/lazy_importer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pi/lazy_importer/lazy_imports.py b/pi/lazy_importer/lazy_imports.py new file mode 100644 index 0000000..216843d --- /dev/null +++ b/pi/lazy_importer/lazy_imports.py @@ -0,0 +1,377 @@ +# -*- coding: utf-8 -*- +import torch +import ast +import copy +import importlib +import inspect +import logging +import re +import sys +from textwrap import dedent +from types import FrameType +from typing import Any, List, Optional, Set, Union + +import pyccolo as pyc +from pyccolo._fast.misc_ast_utils import subscript_to_slice + +logger = logging.getLogger(__name__) + + +_unresolved = object() + +imports_file = set() +import_uses_file = set() +nn_modules_file = set() +functions_file = set() +classes_file = set() +all_calls = set() + + +class _LazySymbol: + non_modules: Set[str] = set() + blocklist_packages: Set[str] = set() + + def __init__(self, spec: Union[ast.Import, ast.ImportFrom]): + self.spec = spec + imports_file.add(ast.unparse(spec)) + self.value = _unresolved + + 
    @property
+    def qualified_module(self) -> str:
+        node = self.spec
+        name = node.names[0].name
+        if isinstance(node, ast.Import):
+            return name
+        else:
+            return f"{node.module}.{name}"
+
+    @staticmethod
+    def top_level_package(module: str) -> str:
+        return module.split(".", 1)[0]
+
+    @classmethod
+    def _unwrap_module(cls, module: str) -> Any:
+        if module in sys.modules:
+            return sys.modules[module]
+        exc = None
+        if module not in cls.non_modules:
+            try:
+                with pyc.allow_reentrant_event_handling():
+                    return importlib.import_module(module)
+            except ImportError as e:
+                # remember names that fail to import as modules so we fall back
+                # to attribute lookup on the parent module below
+                cls.non_modules.add(module)
+                exc = e
+            except Exception:
+                logger.error("fatal error trying to import module %s", module)
+                raise
+        module_symbol = module.rsplit(".", 1)
+        if len(module_symbol) != 2:
+            raise exc
+        else:
+            module, symbol = module_symbol
+            ret = getattr(cls._unwrap_module(module), symbol)
+            if isinstance(ret, _LazySymbol):
+                ret = ret.unwrap()
+            handle_unwrapped_ret(ret)
+            return ret
+
+    def _unwrap_helper(self) -> Any:
+        return self._unwrap_module(self.qualified_module)
+
+    def unwrap(self) -> Any:
+        if self.value is not _unresolved:
+            return self.value
+        ret = self._unwrap_helper()
+        handle_unwrapped_ret(ret)
+        self.value = ret
+        return ret
+
+    def __call__(self, *args, **kwargs):
+        return self.unwrap()(*args, **kwargs)
+
+    def __getattr__(self, item):
+        # forward attribute access to the resolved object
+        return getattr(self.unwrap(), item)
+
+
+class _GetLazyNames(ast.NodeVisitor):
+    def __init__(self) -> None:
+        self.lazy_names: Optional[Set[str]] = set()
+
+    def visit_Import(self, node: ast.Import) -> None:
+        if self.lazy_names is None:
+            return
+        for alias in node.names:
+            if alias.asname is None:
+                return
+        for alias in node.names:
+            self.lazy_names.add(alias.asname)
+
+    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
+        if self.lazy_names is None:
+            return
+        for alias in node.names:
+            if alias.name == "*":
+                self.lazy_names = None
+                return
+        for alias in node.names:
+            self.lazy_names.add(alias.asname or alias.name)
+
+    @classmethod
+    def compute(cls, node: ast.Module) -> Set[str]:
+        inst = cls()
+        inst.visit(node)
+        return inst.lazy_names
+
+
+def _make_attr_guard_helper(node: ast.Attribute) -> Optional[str]:
+    if isinstance(node.value, ast.Name):
+        return f"{node.value.id}_O_{node.attr}"
+    elif isinstance(node.value, ast.Attribute):
+        prefix = _make_attr_guard_helper(node.value)
+        if prefix is None:
+            return None
+        else:
+            return f"{prefix}_O_{node.attr}"
+    else:
+        return None
+
+
+def _make_attr_guard(node: ast.Attribute) -> Optional[str]:
+    suffix = _make_attr_guard_helper(node)
+    if suffix is None:
+        return None
+    else:
+        return f"_Xix_{suffix}"
+
+
+def _make_subscript_guard_helper(node: ast.Subscript) -> Optional[str]:
+    slice_val = subscript_to_slice(node)
+    if isinstance(slice_val, (ast.Constant, ast.Str, ast.Num, ast.Name)):
+        if isinstance(slice_val, ast.Name):
+            subscript = slice_val.id
+        elif hasattr(slice_val, "s"):
+            subscript = f"_{slice_val.s}_"  # type: ignore
+        elif hasattr(slice_val, "n"):
+            subscript = f"_{slice_val.n}_"  # type: ignore
+        else:
+            return None
+    else:
+        return None
+    if isinstance(node.value, ast.Name):
+        return f"{node.value.id}_S_{subscript}"
+    elif isinstance(node.value, ast.Subscript):
+        prefix = _make_subscript_guard_helper(node.value)
+        if prefix is None:
+            return None
+        else:
+            return f"{prefix}_S_{subscript}"
+    else:
+        return None
+
+
+def _make_subscript_guard(node: ast.Subscript) -> Optional[str]:
+    suffix = _make_subscript_guard_helper(node)
+    if suffix is None:
+        return None
+    else:
+        return f"_Xix_{suffix}"
+
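+# Illustrative note (an editor's assumption, not from the original sources): the
+# guard helpers above flatten a load expression into a single global name that
+# init_module seeds to False and the after-load handlers flip to True once the
+# lazy symbol has been resolved, so the instrumented fast path can be taken on
+# subsequent loads. Roughly:
+#   a.b.c  ->  "_Xix_a_O_b_O_c"   (via _make_attr_guard)
+#   a[0]   ->  "_Xix_a_S__0_"     (via _make_subscript_guard)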
+ +MODULE_TARGET = None + + +def handle_unwrapped_ret(ret): + if match := MODULE_TARGET(ret): + if ( + inspect.isclass(ret) + and issubclass(ret, torch.nn.Module) + and match.group() != "torch" + ): + src = dedent(inspect.getsource(ret)) + nn_modules_file.add(src) + elif inspect.isclass(ret): + src = dedent(inspect.getsource(ret)) + classes_file.add(src) + + +def handle_call(ret): + if match := MODULE_TARGET(ret): + if isinstance(ret, torch.nn.Module) and match.group() != "torch": + src = dedent(inspect.getsource(ret.__class__)) + nn_modules_file.add(src) + elif inspect.isclass(ret) and match.group() != "torch": + src = dedent(inspect.getsource(ret)) + classes_file.add(src) + elif inspect.isfunction(ret): + if hasattr(ret, "__name__") and ret.__name__ == "__init__": + # 'super().__init__()' + return + src = dedent(inspect.getsource(ret)) + functions_file.add(src) + + +class LazyImportTracer(pyc.BaseTracer): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.cur_module_lazy_names: Set[str] = set() + self.saved_attributes: List[Any] = [] + self.saved_subscripts: List[Any] = [] + self.saved_slices: List[Any] = [] + + def _is_name_lazy_load(self, node: Union[ast.Attribute, ast.Name]) -> bool: + if self.cur_module_lazy_names is None: + return True + elif isinstance(node, ast.Name): + return node.id in self.cur_module_lazy_names + elif isinstance(node, (ast.Attribute, ast.Subscript)): + return self._is_name_lazy_load(node.value) # type: ignore + elif isinstance(node, ast.Call): + return self._is_name_lazy_load(node.func) + else: + return False + + def static_init_module(self, node: ast.Module) -> None: + self.cur_module_lazy_names = _GetLazyNames.compute(node) + + @staticmethod + def _convert_relative_to_absolute( + package: str, module: Optional[str], level: int + ) -> str: + prefix = package.rsplit(".", level - 1)[0] + if not module: + return prefix + else: + return f"{prefix}.{module}" + + @pyc.init_module + def init_module( + self, _ret: None, node: ast.Module, frame: FrameType, *_, **__ + ) -> None: + assert node is not None + for guard in self.local_guards_by_module_id.get(id(node), []): + frame.f_globals[guard] = False + + @pyc.before_call() + def before_call( + self, + ret: None, + node: Union[ast.Import, ast.ImportFrom], + frame: FrameType, + *_, + **__, + ) -> Any: + handle_call(ret) + return ret + + @pyc.before_stmt( + when=pyc.Predicate( + lambda node: isinstance(node, (ast.Import, ast.ImportFrom)) + and pyc.is_outer_stmt(node), + static=True, + ) + ) + def before_stmt( + self, + _ret: None, + node: Union[ast.Import, ast.ImportFrom], + frame: FrameType, + *_, + **__, + ) -> Any: + is_import = isinstance(node, ast.Import) + for alias in node.names: + if alias.name == "*": + return None + elif is_import and alias.asname is None: + return None + package = frame.f_globals["__package__"] + level = getattr(node, "level", 0) + if is_import: + module = None + else: + module = node.module # type: ignore + if level > 0: + module = self._convert_relative_to_absolute(package, module, level) + for alias in node.names: + node_cpy = copy.deepcopy(node) + node_cpy.names = [alias] + if module is not None: + node_cpy.module = module # type: ignore + node_cpy.level = 0 # type: ignore + lz = _LazySymbol(spec=node_cpy) + frame.f_globals[alias.asname or alias.name] = lz + import_uses_file.add(alias.asname or alias.name) + return pyc.Pass + + @pyc.before_attribute_load( + when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard + ) + def 
before_attr_load(self, ret: Any, *_, **__) -> Any: + self.saved_attributes.append(ret) + return ret + + @pyc.after_attribute_load( + when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard + ) + def after_attr_load( + self, ret: Any, node: ast.Attribute, frame: FrameType, _evt, guard, *_, **__ + ) -> Any: + if guard is not None: + frame.f_globals[guard] = True + saved_attr_obj = self.saved_attributes.pop() + if isinstance(ret, _LazySymbol): + ret = ret.unwrap() + setattr(saved_attr_obj, node.attr, ret) + return pyc.Null if ret is None else ret + + @pyc.before_subscript_load( + when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard + ) + def before_subscript_load(self, ret: Any, *_, attr_or_subscript: Any, **__) -> Any: + self.saved_subscripts.append(ret) + self.saved_slices.append(attr_or_subscript) + return ret + + @pyc.after_subscript_load( + when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard + ) + def after_subscript_load( + self, ret: Any, _node, frame: FrameType, _evt, guard, *_, **__ + ) -> Any: + if guard is not None: + frame.f_globals[guard] = True + saved_subscript_obj = self.saved_subscripts.pop() + saved_slice_obj = self.saved_slices.pop() + if isinstance(ret, _LazySymbol): + ret = ret.unwrap() + saved_subscript_obj[saved_slice_obj] = ret + return pyc.Null if ret is None else ret + + @pyc.load_name( + when=pyc.Predicate(_is_name_lazy_load, static=True), + guard=lambda node: f"_Xix_{node.id}_guard", + ) + def load_name( + self, ret: Any, node: ast.Name, frame: FrameType, _evt, guard, *_, **__ + ) -> Any: + if guard is not None: + frame.f_globals[guard] = True + if isinstance(ret, _LazySymbol): + ret = ret.unwrap() + frame.f_globals[node.id] = ret + else: + print("notlazysymbol", ret) + return pyc.Null if ret is None else ret + + +def capture_all_calls(frame: FrameType, event, arg): + if event != "call": + return + + try: + if MODULE_TARGET(frame.f_code): + code = inspect.getsource(frame.f_code) + if not re.match(r"\s", code): + all_calls.add(code) + except: + pass diff --git a/pi/lazy_importer/run_lazy_imports.py b/pi/lazy_importer/run_lazy_imports.py new file mode 100644 index 0000000..a85966b --- /dev/null +++ b/pi/lazy_importer/run_lazy_imports.py @@ -0,0 +1,76 @@ +import sys +from textwrap import dedent + +from pi.lazy_importer.lazy_imports import ( + LazyImportTracer, + imports_file, + import_uses_file, + nn_modules_file, + classes_file, + capture_all_calls, + functions_file, +) + + +class LazyImports(LazyImportTracer): + def should_instrument_file(self, filename: str) -> bool: + return not filename.endswith("lazy_imports.py") + + +# /home/mlevental/mambaforge/envs/PI/lib/python3.11/site-packages/pyccolo/ast_rewriter.py:161 +PREFIX = f"""\ +from __future__ import annotations +import functools +import importlib +import inspect +import json +import logging +import math +import os +from collections import defaultdict +from dataclasses import dataclass, fields +from functools import partial +from pathlib import PosixPath +from typing import Optional, Callable, Tuple, Union, Dict, List, Any, OrderedDict +import sys +import warnings + +import numpy as np +import torch +from torch import nn, Tensor +from torch.nn import functional as F +logger = logging.getLogger(__name__) + +""" + + +def do_hand_imports(closure, closure2): + with LazyImports.instance(): + closure() + with open("imports.py", "w") as f: + for n in sorted(imports_file): + print(n, file=f, flush=True) + for n in sorted(import_uses_file): + 
print( + dedent( + f"""\ + try: {n}() + except: pass + """ + ), + file=f, + flush=True, + ) + with LazyImports.instance(): + closure2() + + +def do_package_imports(closure, prefix, name): + with LazyImports.instance(): + closure() + + with open(f"{name}.py", "w") as f: + print(PREFIX, file=f, flush=True) + print(prefix, file=f, flush=True) + for n in sorted(functions_file | classes_file | nn_modules_file): + print(n, file=f, flush=True) diff --git a/pi/models/unet.py b/pi/models/unet.py deleted file mode 100644 index 7491e73..0000000 --- a/pi/models/unet.py +++ /dev/null @@ -1,11078 +0,0 @@ -import importlib -import json -import logging -import sys -from pathlib import Path, PosixPath - -logger = logging.getLogger(__name__) - -import functools -import operator as op -import os -from dataclasses import dataclass, fields -from collections import OrderedDict, defaultdict -from copy import deepcopy -from functools import partial -from torch import Tensor, device -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from uuid import uuid4 -import inspect -import math -import numpy as np -import re -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -import warnings - - -class MakeCutouts(nn.Module): - def __init__(self, cut_size, cut_power=1.0): - super().__init__() - self.cut_size = cut_size - self.cut_power = cut_power - - def forward(self, pixel_values, num_cutouts): - sideY, sideX = pixel_values.shape[2:4] - max_size = min(sideX, sideY) - min_size = min(sideX, sideY, self.cut_size) - cutouts = [] - for _ in range(num_cutouts): - size = int( - torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size - ) - offsetx = torch.randint(0, sideX - size + 1, ()) - offsety = torch.randint(0, sideY - size + 1, ()) - cutout = pixel_values[ - :, :, offsety : offsety + size, offsetx : offsetx + size - ] - cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) - return torch.cat(cutouts) - - -class DiffusionUncond(nn.Module): - def __init__(self, global_args): - super().__init__() - self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) - self.diffusion_ema = deepcopy(self.diffusion) - self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - - -class AttnProcsLayers(torch.nn.Module): - def __init__(self, state_dict: "Dict[str, torch.Tensor]"): - super().__init__() - self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = {k: v for k, v in enumerate(state_dict.keys())} - self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - - def map_to(module, state_dict, *args, **kwargs): - new_state_dict = {} - for key, value in state_dict.items(): - num = int(key.split(".")[1]) - new_key = key.replace(f"layers.{num}", module.mapping[num]) - new_state_dict[new_key] = value - return new_state_dict - - def map_from(module, state_dict, *args, **kwargs): - all_keys = list(state_dict.keys()) - for key in all_keys: - replace_key = key.split(".processor")[0] + ".processor" - new_key = key.replace( - replace_key, f"layers.{module.rev_mapping[replace_key]}" - ) - state_dict[new_key] = state_dict[key] - del state_dict[key] - - self._register_state_dict_hook(map_to) - self._register_load_state_dict_pre_hook(map_from, with_module=True) - - -def is_xformers_available(): - return _xformers_available - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. 
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - def __init__( - self, - channels: "int", - num_head_channels: "Optional[int]" = None, - norm_num_groups: "int" = 32, - rescale_output_factor: "float" = 1.0, - eps: "float" = 1e-05, - ): - super().__init__() - self.channels = channels - self.num_heads = ( - channels // num_head_channels if num_head_channels is not None else 1 - ) - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm( - num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True - ) - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - self._use_memory_efficient_attention_xformers = False - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size * head_size, seq_len, dim // head_size - ) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, - use_memory_efficient_attention_xformers: "bool", - attention_op: "Optional[Callable]" = None, - ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only available for GPU " - ) - else: - try: - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = ( - use_memory_efficient_attention_xformers - ) - self._attention_op = attention_op - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - hidden_states = self.group_norm(hidden_states) - hidden_states = hidden_states.view(batch, channel, height * width).transpose( - 1, 2 - ) - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - scale = 1 / math.sqrt(self.channels / self.num_heads) - query_proj = self.reshape_heads_to_batch_dim(query_proj) - key_proj = self.reshape_heads_to_batch_dim(key_proj) - value_proj = self.reshape_heads_to_batch_dim(value_proj) - if self._use_memory_efficient_attention_xformers: - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op - ) - hidden_states = hidden_states - else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type( - attention_scores.dtype - ) - hidden_states = torch.bmm(attention_probs, value_proj) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch, channel, height, width - ) - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - -class AdaLayerNorm(nn.Module): - """ - Norm layer modified to incorporate timestep embeddings. - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) - - def forward(self, x, timestep): - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = torch.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x - - -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding( - num_classes + use_cfg_embedding, hidden_size - ) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. 
- """ - if force_drop_ids is None: - drop_ids = ( - torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - ) - else: - drop_ids = torch.tensor(force_drop_ids == 1) - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if self.training and use_dropout or force_drop_ids is not None: - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: "int", - time_embed_dim: "int", - act_fn: "str" = "silu", - out_dim: "int" = None, - post_act_fn: "Optional[str]" = None, - cond_proj_dim=None, - ): - super().__init__() - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) - else: - self.cond_proj = None - if act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "mish": - self.act = nn.Mish() - elif act_fn == "gelu": - self.act = nn.GELU() - else: - raise ValueError( - f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" - ) - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - if post_act_fn is None: - self.post_act = None - elif post_act_fn == "silu": - self.post_act = nn.SiLU() - elif post_act_fn == "mish": - self.post_act = nn.Mish() - elif post_act_fn == "gelu": - self.post_act = nn.GELU() - else: - raise ValueError( - f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" - ) - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - if self.act is not None: - sample = self.act(sample) - sample = self.linear_2(sample) - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -def get_timestep_embedding( - timesteps: "torch.Tensor", - embedding_dim: "int", - flip_sin_to_cos: "bool" = False, - downscale_freq_shift: "float" = 1, - scale: "float" = 1, - max_period: "int" = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. 
- """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - emb = scale * emb - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -class Timesteps(nn.Module): - def __init__( - self, - num_channels: "int", - flip_sin_to_cos: "bool", - downscale_freq_shift: "float", - ): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class CombinedTimestepLabelEmbeddings(nn.Module): - def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): - super().__init__() - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1 - ) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, time_embed_dim=embedding_dim - ) - self.class_embedder = LabelEmbedding( - num_classes, embedding_dim, class_dropout_prob - ) - - def forward(self, timestep, class_labels, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj) - class_labels = self.class_embedder(class_labels) - conditioning = timesteps_emb + class_labels - return conditioning - - -class AdaLayerNormZero(nn.Module): - """ - Norm layer adaptive layer norm zero (adaLN-Zero). - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-06) - - def forward(self, x, timestep, class_labels, hidden_dtype=None): - emb = self.linear( - self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)) - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( - 6, dim=1 - ) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp - - -class AttnProcessor2_0: - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - inner_dim = hidden_states.shape[-1] - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - attention_mask = attention_mask.view( - batch_size, attn.heads, -1, attention_mask.shape[-1] - ) - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class CrossAttnAddedKVProcessor: - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - residual = hidden_states - hidden_states = hidden_states.view( - hidden_states.shape[0], hidden_states.shape[1], -1 - ).transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.head_to_batch_dim( - encoder_hidden_states_key_proj - ) - encoder_hidden_states_value_proj = attn.head_to_batch_dim( - encoder_hidden_states_value_proj - ) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - return hidden_states - - -class CrossAttnProcessor: - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, 
sequence_length, batch_size - ) - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): - super().__init__() - if rank > min(in_features, out_features): - raise ValueError( - f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" - ) - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - down_hidden_states = self.down(hidden_states) - up_hidden_states = self.up(down_hidden_states) - return up_hidden_states - - -class LoRACrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): - super().__init__() - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size, rank - ) - self.to_v_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size, rank - ) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0, - ): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query) - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( - encoder_hidden_states - ) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( - hidden_states - ) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class LoRAXFormersCrossAttnProcessor(nn.Module): - def __init__( - self, - hidden_size, - cross_attention_dim, - rank=4, - attention_op: "Optional[Callable]" = None, - ): - super().__init__() - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.rank = rank - self.attention_op = attention_op - 
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size, rank - ) - self.to_v_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size, rank - ) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0, - ): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query).contiguous() - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( - encoder_hidden_states - ) - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - hidden_states = xformers.ops.memory_efficient_attention( - query, - key, - value, - attn_bias=attention_mask, - op=self.attention_op, - scale=attn.scale, - ) - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( - hidden_states - ) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class SlicedAttnAddedKVProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__( - self, - attn: "'CrossAttention'", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - residual = hidden_states - hidden_states = hidden_states.view( - hidden_states.shape[0], hidden_states.shape[1], -1 - ).transpose(1, 2) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.head_to_batch_dim( - encoder_hidden_states_key_proj - ) - encoder_hidden_states_value_proj = attn.head_to_batch_dim( - encoder_hidden_states_value_proj - ) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - batch_size_attention, query_tokens, _ = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) - attn_slice = 
attn.get_attention_scores( - query_slice, key_slice, attn_mask_slice - ) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - hidden_states[start_idx:end_idx] = attn_slice - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - return hidden_states - - -class SlicedAttnProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - batch_size_attention, query_tokens, _ = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) - attn_slice = attn.get_attention_scores( - query_slice, key_slice, attn_mask_slice - ) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - hidden_states[start_idx:end_idx] = attn_slice - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class XFormersCrossAttnProcessor: - def __init__(self, attention_op: "Optional[Callable]" = None): - self.attention_op = attention_op - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - hidden_states = xformers.ops.memory_efficient_attention( - query, - key, - value, - attn_bias=attention_mask, - op=self.attention_op, - scale=attn.scale, - ) - hidden_states = hidden_states - hidden_states = attn.batch_to_head_dim(hidden_states) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return 
hidden_states - - -def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): - deprecated_kwargs = take_from - values = () - if not isinstance(args[0], tuple): - args = (args,) - for attribute, version_name, message in args: - if version.parse(version.parse(__version__).base_version) >= version.parse( - version_name - ): - raise ValueError( - f"The deprecation tuple {attribute, version_name, message} should be removed since diffusers' version {__version__} is >= {version_name}" - ) - warning = None - if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: - values += (deprecated_kwargs.pop(attribute),) - warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." - elif hasattr(deprecated_kwargs, attribute): - values += (getattr(deprecated_kwargs, attribute),) - warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." - elif deprecated_kwargs is None: - warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." - if warning is not None: - warning = warning + " " if standard_warn else "" - warnings.warn(warning + message, FutureWarning, stacklevel=2) - if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: - call_frame = inspect.getouterframes(inspect.currentframe())[1] - filename = call_frame.filename - line_number = call_frame.lineno - function = call_frame.function - key, value = next(iter(deprecated_kwargs.items())) - raise TypeError( - f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`" - ) - if len(values) == 0: - return - elif len(values) == 1: - return values[0] - return values - - -class CrossAttention(nn.Module): - """ - A cross attention layer. - - Parameters: - query_dim (`int`): The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
- """ - - def __init__( - self, - query_dim: "int", - cross_attention_dim: "Optional[int]" = None, - heads: "int" = 8, - dim_head: "int" = 64, - dropout: "float" = 0.0, - bias=False, - upcast_attention: "bool" = False, - upcast_softmax: "bool" = False, - cross_attention_norm: "bool" = False, - added_kv_proj_dim: "Optional[int]" = None, - norm_num_groups: "Optional[int]" = None, - out_bias: "bool" = True, - scale_qk: "bool" = True, - processor: "Optional['AttnProcessor']" = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = ( - cross_attention_dim if cross_attention_dim is not None else query_dim - ) - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.cross_attention_norm = cross_attention_norm - self.scale = dim_head ** -0.5 if scale_qk else 1.0 - self.heads = heads - self.sliceable_head_dim = heads - self.added_kv_proj_dim = added_kv_proj_dim - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm( - num_channels=inner_dim, - num_groups=norm_num_groups, - eps=1e-05, - affine=True, - ) - else: - self.group_norm = None - if cross_attention_norm: - self.norm_cross = nn.LayerNorm(cross_attention_dim) - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - if processor is None: - processor = ( - AttnProcessor2_0() - if hasattr(F, "scaled_dot_product_attention") and scale_qk - else CrossAttnProcessor() - ) - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, - use_memory_efficient_attention_xformers: "bool", - attention_op: "Optional[Callable]" = None, - ): - is_lora = hasattr(self, "processor") and isinstance( - self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) - ) - if use_memory_efficient_attention_xformers: - if self.added_kv_proj_dim is not None: - raise NotImplementedError( - "Memory efficient attention with `xformers` is currently not supported when `self.added_kv_proj_dim` is defined." - ) - elif not is_xformers_available(): - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only available for GPU " - ) - else: - try: - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - if is_lora: - processor = LoRAXFormersCrossAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - processor - else: - processor = XFormersCrossAttnProcessor(attention_op=attention_op) - elif is_lora: - processor = LoRACrossAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - ) - processor.load_state_dict(self.processor.state_dict()) - processor - else: - processor = CrossAttnProcessor() - self.set_processor(processor) - - def set_attention_slice(self, slice_size): - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." - ) - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = CrossAttnAddedKVProcessor() - else: - processor = CrossAttnProcessor() - self.set_processor(processor) - - def set_processor(self, processor: "'AttnProcessor'"): - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info( - f"You are removing possibly trained weights of {self.processor} with {processor}" - ) - self._modules.pop("processor") - self.processor = processor - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - **cross_attention_kwargs, - ): - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def head_to_batch_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size * head_size, seq_len, dim // head_size - ) - return tensor - - def get_attention_scores(self, query, key, attention_mask=None): - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - attention_scores = torch.baddbmm( - baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale - ) - if self.upcast_softmax: - attention_scores = attention_scores.float() - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs - return attention_probs - - def 
prepare_attention_mask(self, attention_mask, target_length, batch_size=None): - if batch_size is None: - deprecate( - "batch_size=None", - "0.0.15", - "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to `prepare_attention_mask` when preparing the attention_mask.", - ) - batch_size = 1 - head_size = self.heads - if attention_mask is None: - return attention_mask - if attention_mask.shape[-1] != target_length: - if attention_mask.device.type == "mps": - padding_shape = ( - attention_mask.shape[0], - attention_mask.shape[1], - target_length, - ) - padding = torch.zeros( - padding_shape, - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - return attention_mask - - -class ApproximateGELU(nn.Module): - """ - The approximate form of Gaussian Error Linear Unit (GELU) - - For more details, see section 2: https://arxiv.org/abs/1606.08415 - """ - - def __init__(self, dim_in: "int", dim_out: "int"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - - def forward(self, x): - x = self.proj(x) - return x * torch.sigmoid(1.702 * x) - - -class GEGLU(nn.Module): - """ - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: "int", dim_out: "int"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate) - return F.gelu(gate.to(dtype=torch.float32)) - - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class GELU(nn.Module): - """ - GELU activation function with tanh approximation support with `approximate="tanh"`. - """ - - def __init__(self, dim_in: "int", dim_out: "int", approximate: "str" = "none"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - self.approximate = approximate - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states - - -class FeedForward(nn.Module): - """ - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 
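- 
- Example (a minimal sketch; the sizes are illustrative, and the default `mult=4` gives a hidden
- width of 4 * dim):
- 
-     >>> ff = FeedForward(dim=320, activation_fn="geglu")   # hidden width = 4 * 320 = 1280
-     >>> x = torch.randn(2, 64, 320)                        # (batch, tokens, dim)
-     >>> ff(x).shape
-     torch.Size([2, 64, 320])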
- """ - - def __init__( - self, - dim: "int", - dim_out: "Optional[int]" = None, - mult: "int" = 4, - dropout: "float" = 0.0, - activation_fn: "str" = "geglu", - final_dropout: "bool" = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - self.net = nn.ModuleList([]) - self.net.append(act_fn) - self.net.append(nn.Dropout(dropout)) - self.net.append(nn.Linear(inner_dim, dim_out)) - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class BasicTransformerBlock(nn.Module): - """ - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: "int", - num_attention_heads: "int", - attention_head_dim: "int", - dropout=0.0, - cross_attention_dim: "Optional[int]" = None, - activation_fn: "str" = "geglu", - num_embeds_ada_norm: "Optional[int]" = None, - attention_bias: "bool" = False, - only_cross_attention: "bool" = False, - upcast_attention: "bool" = False, - norm_elementwise_affine: "bool" = True, - norm_type: "str" = "layer_norm", - final_dropout: "bool" = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - self.use_ada_layer_norm_zero = ( - num_embeds_ada_norm is not None and norm_type == "ada_norm_zero" - ) - self.use_ada_layer_norm = ( - num_embeds_ada_norm is not None and norm_type == "ada_norm" - ) - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 
- ) - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - ) - if cross_attention_dim is not None: - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - else: - self.attn2 = None - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - if cross_attention_dim is not None: - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - else: - self.norm2 = None - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - def forward( - self, - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=None, - cross_attention_kwargs=None, - class_labels=None, - ): - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states - if self.only_cross_attention - else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) - if self.use_ada_layer_norm - else self.norm2(hidden_states) - ) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - norm_hidden_states = self.norm3(hidden_states) - if self.use_ada_layer_norm_zero: - norm_hidden_states = ( - norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ) - ff_output = self.ff(norm_hidden_states) - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - hidden_states = ff_output + hidden_states - return hidden_states - - -class AdaGroupNorm(nn.Module): - """ - GroupNorm layer modified to incorporate timestep embeddings. 
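- 
- Example (a minimal sketch; the embedding and channel sizes are illustrative, `x` must carry
- `out_dim` channels and `num_groups` must divide `out_dim`):
- 
-     >>> norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=32, act_fn="silu")
-     >>> x = torch.randn(2, 64, 16, 16)    # (batch, out_dim channels, height, width)
-     >>> emb = torch.randn(2, 512)         # e.g. a timestep embedding
-     >>> norm(x, emb).shape
-     torch.Size([2, 64, 16, 16])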
- """ - - def __init__( - self, - embedding_dim: "int", - out_dim: "int", - num_groups: "int", - act_fn: "Optional[str]" = None, - eps: "float" = 1e-05, - ): - super().__init__() - self.num_groups = num_groups - self.eps = eps - self.act = None - if act_fn == "swish": - self.act = lambda x: F.silu(x) - elif act_fn == "mish": - self.act = nn.Mish() - elif act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "gelu": - self.act = nn.GELU() - self.linear = nn.Linear(embedding_dim, out_dim * 2) - - def forward(self, x, emb): - if self.act: - emb = self.act(emb) - emb = self.linear(emb) - emb = emb[:, :, None, None] - scale, shift = emb.chunk(2, dim=1) - x = F.group_norm(x, self.num_groups, eps=self.eps) - x = x * (1 + scale) + shift - return x - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module - - -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: "int", - conditioning_channels: "int" = 3, - block_out_channels: "Tuple[int]" = (16, 32, 96, 256), - ): - super().__init__() - self.conv_in = nn.Conv2d( - conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 - ) - self.blocks = nn.ModuleList([]) - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append( - nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1) - ) - self.blocks.append( - nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2) - ) - self.conv_out = zero_module( - nn.Conv2d( - block_out_channels[-1], - conditioning_embedding_channels, - kernel_size=3, - padding=1, - ) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - embedding = self.conv_out(embedding) - return embedding - - -STR_OPERATION_TO_FUNC = { - ">": op.gt, - ">=": op.ge, - "==": op.eq, - "!=": op.ne, - "<=": op.le, - "<": op.lt, -} - - -def compare_versions( - library_or_version: "Union[str, Version]", - operation: "str", - requirement_version: "str", -): - """ - Args: - Compares a library version to some requirement using a given operation. - library_or_version (`str` or `packaging.version.Version`): - A library name or a version to check. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="`. 
- requirement_version (`str`): - The version to compare the library version against - """ - if operation not in STR_OPERATION_TO_FUNC.keys(): - raise ValueError( - f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}" - ) - operation = STR_OPERATION_TO_FUNC[operation] - if isinstance(library_or_version, str): - library_or_version = parse(importlib_metadata.version(library_or_version)) - return operation(library_or_version, parse(requirement_version)) - - -def is_transformers_version(operation: "str", version: "str"): - """ - Args: - Compares the current Transformers version to a given reference with an operation. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A version string - """ - if not _transformers_available: - return False - return compare_versions(parse(_transformers_version), operation, version) - - -def requires_backends(obj, backends): - if not isinstance(backends, (list, tuple)): - backends = [backends] - name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] - if failed: - raise ImportError("".join(failed)) - if ( - name - in [ - "VersatileDiffusionTextToImagePipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionDualGuidedPipeline", - "StableDiffusionImageVariationPipeline", - "UnCLIPPipeline", - ] - and is_transformers_version("<", "4.25.0") - ): - raise ImportError( - f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install --upgrade transformers \n```" - ) - if ( - name - in [ - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionPix2PixZeroPipeline", - ] - and is_transformers_version("<", "4.26.0") - ): - raise ImportError( - f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install --upgrade transformers \n```" - ) - - -class DummyObject(type): - """ - Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by - `requires_backend` each time a user tries to access any method of that class. - """ - - def __getattr__(cls, key): - if key.startswith("_"): - return super().__getattr__(cls, key) - requires_backends(cls, cls._backends) - - -class FrozenDict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - for key, value in self.items(): - setattr(self, key, value) - self.__frozen = True - - def __delitem__(self, *args, **kwargs): - raise Exception( - f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." - ) - - def setdefault(self, *args, **kwargs): - raise Exception( - f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." - ) - - def pop(self, *args, **kwargs): - raise Exception( - f"You cannot use ``pop`` on a {self.__class__.__name__} instance." - ) - - def update(self, *args, **kwargs): - raise Exception( - f"You cannot use ``update`` on a {self.__class__.__name__} instance." - ) - - def __setattr__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception( - f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." - ) - super().__setattr__(name, value) - - def __setitem__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception( - f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." 
- ) - super().__setitem__(name, value) - - -HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" - - -def extract_commit_hash( - resolved_file: "Optional[str]", commit_hash: "Optional[str]" = None -): - """ - Extracts the commit hash from a resolved filename toward a cache file. - """ - if resolved_file is None or commit_hash is not None: - return commit_hash - resolved_file = str(Path(resolved_file).as_posix()) - search = re.search("snapshots/([^/]+)/", resolved_file) - if search is None: - return None - commit_hash = search.groups()[0] - return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None - - -ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} - -SESSION_ID = uuid4().hex - -_flax_version = "N/A" - -_jax_version = "N/A" - -_onnxruntime_version = "N/A" - -_torch_version = "N/A" - - -class ConfigMixin: - """ - Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all - methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with - - [`~ConfigMixin.from_config`] - - [`~ConfigMixin.save_config`] - - Class attributes: - - **config_name** (`str`) -- A filename under which the config should stored when calling - [`~ConfigMixin.save_config`] (should be overridden by parent class). - - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - overridden by subclass). - - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). - - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function - should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by - subclass). - """ - - config_name = None - ignore_for_config = [] - has_compatibles = False - _deprecated_kwargs = [] - - def register_to_config(self, **kwargs): - if self.config_name is None: - raise NotImplementedError( - f"Make sure that {self.__class__} has defined a class name `config_name`" - ) - kwargs.pop("kwargs", None) - for key, value in kwargs.items(): - try: - setattr(self, key, value) - except AttributeError as err: - logger.error(f"Can't set {key} with value {value} for {self}") - raise err - if not hasattr(self, "_internal_dict"): - internal_dict = kwargs - else: - previous_dict = dict(self._internal_dict) - internal_dict = {**self._internal_dict, **kwargs} - logger.debug(f"Updating config from {previous_dict} to {internal_dict}") - self._internal_dict = FrozenDict(internal_dict) - - def save_config( - self, - save_directory: "Union[str, os.PathLike]", - push_to_hub: "bool" = False, - **kwargs, - ): - if os.path.isfile(save_directory): - raise AssertionError( - f"Provided path ({save_directory}) should be a directory, not a file" - ) - os.makedirs(save_directory, exist_ok=True) - output_config_file = os.path.join(save_directory, self.config_name) - self.to_json_file(output_config_file) - logger.info(f"Configuration saved in {output_config_file}") - - @classmethod - def from_config( - cls, - config: "Union[FrozenDict, Dict[str, Any]]" = None, - return_unused_kwargs=False, - **kwargs, - ): - if "pretrained_model_name_or_path" in kwargs: - config = kwargs.pop("pretrained_model_name_or_path") - if config is None: - raise ValueError( - "Please make sure to provide a config as the first positional argument." - ) - if not isinstance(config, dict): - deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." 
- if "Scheduler" in cls.__name__: - deprecation_message += f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0." - elif "Model" in cls.__name__: - deprecation_message += f"If you were trying to load a model, please use {cls}.load_config(...) followed by {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0." - deprecate( - "config-passed-as-path", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - config, kwargs = cls.load_config( - pretrained_model_name_or_path=config, - return_unused_kwargs=True, - **kwargs, - ) - init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) - if "dtype" in unused_kwargs: - init_dict["dtype"] = unused_kwargs.pop("dtype") - for deprecated_kwarg in cls._deprecated_kwargs: - if deprecated_kwarg in unused_kwargs: - init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) - model = cls(**init_dict) - model.register_to_config(**hidden_dict) - unused_kwargs = {**unused_kwargs, **hidden_dict} - if return_unused_kwargs: - return model, unused_kwargs - else: - return model - - @classmethod - def get_config_dict(cls, *args, **kwargs): - deprecation_message = f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be removed in version v1.0.0" - deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) - return cls.load_config(*args, **kwargs) - - @classmethod - def load_config( - cls, - pretrained_model_name_or_path: "Union[str, os.PathLike]", - return_unused_kwargs=False, - return_commit_hash=False, - **kwargs, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - Instantiate a Python class from a config dictionary - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., - `./my_model_directory/`. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - return_unused_kwargs (`bool`, *optional*, defaults to `False): - Whether unused keyword arguments of the config shall be returned. - return_commit_hash (`bool`, *optional*, defaults to `False): - Whether the commit_hash of the loaded configuration shall be returned. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. - - - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - use_auth_token = kwargs.pop("use_auth_token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - _ = kwargs.pop("mirror", None) - subfolder = kwargs.pop("subfolder", None) - user_agent = kwargs.pop("user_agent", {}) - user_agent = {**user_agent, "file_type": "config"} - user_agent = http_user_agent(user_agent) - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if cls.config_name is None: - raise ValueError( - "`self.config_name` is not defined. Note that one should not load a config from `ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" - ) - if os.path.isfile(pretrained_model_name_or_path): - config_file = pretrained_model_name_or_path - elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile( - os.path.join(pretrained_model_name_or_path, cls.config_name) - ): - config_file = os.path.join( - pretrained_model_name_or_path, cls.config_name - ) - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) - ): - config_file = os.path.join( - pretrained_model_name_or_path, subfolder, cls.config_name - ) - else: - raise EnvironmentError( - f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." - ) - else: - pass - try: - config_dict = cls._dict_from_json_file(config_file) - commit_hash = extract_commit_hash(config_file) - except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError( - f"It looks like the config file at '{config_file}' is not a valid JSON file." 
- ) - if not (return_unused_kwargs or return_commit_hash): - return config_dict - outputs = (config_dict,) - if return_unused_kwargs: - outputs += (kwargs,) - if return_commit_hash: - outputs += (commit_hash,) - return outputs - - @staticmethod - def _get_init_keys(cls): - return set(dict(inspect.signature(cls.__init__).parameters).keys()) - - @classmethod - def extract_init_dict(cls, config_dict, **kwargs): - original_dict = {k: v for k, v in config_dict.items()} - expected_keys = cls._get_init_keys(cls) - expected_keys.remove("self") - if "kwargs" in expected_keys: - expected_keys.remove("kwargs") - if hasattr(cls, "_flax_internal_args"): - for arg in cls._flax_internal_args: - expected_keys.remove(arg) - if len(cls.ignore_for_config) > 0: - expected_keys = expected_keys - set(cls.ignore_for_config) - diffusers_library = importlib.import_module(__name__.split(".")[0]) - if cls.has_compatibles: - compatible_classes = [ - c for c in cls._get_compatibles() if not isinstance(c, DummyObject) - ] - else: - compatible_classes = [] - expected_keys_comp_cls = set() - for c in compatible_classes: - expected_keys_c = cls._get_init_keys(c) - expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) - expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) - config_dict = { - k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls - } - orig_cls_name = config_dict.pop("_class_name", cls.__name__) - if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): - orig_cls = getattr(diffusers_library, orig_cls_name) - unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys - config_dict = { - k: v - for k, v in config_dict.items() - if k not in unexpected_keys_from_orig - } - config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} - init_dict = {} - for key in expected_keys: - if key in kwargs and key in config_dict: - config_dict[key] = kwargs.pop(key) - if key in kwargs: - init_dict[key] = kwargs.pop(key) - elif key in config_dict: - init_dict[key] = config_dict.pop(key) - if len(config_dict) > 0: - logger.warning( - f"The config attributes {config_dict} were passed to {cls.__name__}, but are not expected and will be ignored. Please verify your {cls.config_name} configuration file." - ) - passed_keys = set(init_dict.keys()) - if len(expected_keys - passed_keys) > 0: - logger.info( - f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." - ) - unused_kwargs = {**config_dict, **kwargs} - hidden_config_dict = { - k: v for k, v in original_dict.items() if k not in init_dict - } - return init_dict, unused_kwargs, hidden_config_dict - - @classmethod - def _dict_from_json_file(cls, json_file: "Union[str, os.PathLike]"): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - return json.loads(text) - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - @property - def config(self) -> Dict[str, Any]: - """ - Returns the config of the class as a frozen dictionary - - Returns: - `Dict[str, Any]`: Config of the class. - """ - return self._internal_dict - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string. - - Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. 
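- 
- Example (illustrative only; given an already-configured instance `model`, the exact keys and the
- `_diffusers_version` value depend on the subclass and the installed version):
- 
-     >>> print(model.to_json_string())
-     {
-       "_class_name": "UNet2DConditionModel",
-       "_diffusers_version": "...",
-       "sample_size": 32
-     }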
- """ - config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} - config_dict["_class_name"] = self.__class__.__name__ - config_dict["_diffusers_version"] = __version__ - - def to_json_saveable(value): - if isinstance(value, np.ndarray): - value = value.tolist() - elif isinstance(value, PosixPath): - value = str(value) - return value - - config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} - return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path: "Union[str, os.PathLike]"): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - -class ImagePositionalEmbeddings(nn.Module): - """ - Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the - height and width of the latent space. - - For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 - - For VQ-diffusion: - - Output vector embeddings are used as input for the transformer. - - Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. - - Args: - num_embed (`int`): - Number of embeddings for the latent pixels embeddings. - height (`int`): - Height of the latent image i.e. the number of height embeddings. - width (`int`): - Width of the latent image i.e. the number of width embeddings. - embed_dim (`int`): - Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. - """ - - def __init__(self, num_embed: "int", height: "int", width: "int", embed_dim: "int"): - super().__init__() - self.height = height - self.width = width - self.num_embed = num_embed - self.embed_dim = embed_dim - self.emb = nn.Embedding(self.num_embed, embed_dim) - self.height_emb = nn.Embedding(self.height, embed_dim) - self.width_emb = nn.Embedding(self.width, embed_dim) - - def forward(self, index): - emb = self.emb(index) - height_emb = self.height_emb( - torch.arange(self.height, device=index.device).view(1, self.height) - ) - height_emb = height_emb.unsqueeze(2) - width_emb = self.width_emb( - torch.arange(self.width, device=index.device).view(1, self.width) - ) - width_emb = width_emb.unsqueeze(1) - pos_emb = height_emb + width_emb - pos_emb = pos_emb.view(1, self.height * self.width, -1) - emb = emb + pos_emb[:, : emb.shape[1], :] - return emb - - -CONFIG_NAME = "config.json" - -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" - -SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" - -WEIGHTS_NAME = "diffusion_pytorch_model.bin" - - -def _add_variant(weights_name: "str", variant: "Optional[str]" = None) -> str: - if variant is not None: - splits = weights_name.split(".") - splits = splits[:-1] + [variant] + splits[-1:] - weights_name = ".".join(splits) - return weights_name - - -DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] - - -def _get_model_file( - pretrained_model_name_or_path, - *, - weights_name, - subfolder, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, - user_agent, - revision, - commit_hash=None, -): - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): - return pretrained_model_name_or_path - elif 
os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): - model_file = os.path.join(pretrained_model_name_or_path, weights_name) - return model_file - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - ): - model_file = os.path.join( - pretrained_model_name_or_path, subfolder, weights_name - ) - return model_file - else: - raise EnvironmentError( - f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." - ) - else: - if ( - revision in DEPRECATED_REVISION_ARGS - and ( - weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME - ) - and version.parse(version.parse(__version__).base_version) - >= version.parse("0.17.0") - ): - try: - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=_add_variant(weights_name, revision), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision or commit_hash, - ) - warnings.warn( - f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - return model_file - except: - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. 
\n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", - FutureWarning, - ) - - -def _load_state_dict_into_model(model_to_load, state_dict): - state_dict = state_dict.copy() - error_msgs = [] - - def load(module: "torch.nn.Module", prefix=""): - args = state_dict, prefix, {}, True, [], [], error_msgs - module._load_from_state_dict(*args) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(model_to_load) - return error_msgs - - -def get_parameter_device(parameter: "torch.nn.Module"): - try: - return next(parameter.parameters()).device - except StopIteration: - - def find_tensor_attributes( - module: "torch.nn.Module", - ) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].device - - -def get_parameter_dtype(parameter: "torch.nn.Module"): - try: - return next(parameter.parameters()).dtype - except StopIteration: - - def find_tensor_attributes( - module: "torch.nn.Module", - ) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype - - -def is_torch_version(operation: "str", version: "str"): - """ - Args: - Compares the current PyTorch version to a given reference with an operation. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A string version of PyTorch - """ - return compare_versions(parse(_torch_version), operation, version) - - -def load_state_dict( - checkpoint_file: "Union[str, os.PathLike]", variant: "Optional[str]" = None -): - """ - Reads a checkpoint file, returning properly formatted errors if they arise. - """ - try: - if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): - return torch.load(checkpoint_file, map_location="cpu") - else: - return safetensors.torch.load_file(checkpoint_file, device="cpu") - except Exception as e: - try: - with open(checkpoint_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install git-lfs and run `git lfs install` followed by `git lfs pull` in the folder you cloned." - ) - else: - raise ValueError( - f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained model. Make sure you have saved the model properly." - ) from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." - ) - - -class ModelMixin(torch.nn.Module): - """ - Base class for all models. - - [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading - and saving models. - - - **config_name** ([`str`]) -- A filename under which the model should be stored when calling - [`~models.ModelMixin.save_pretrained`]. 
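- 
- Example (a minimal save/load round trip sketch; `UNet2DConditionModel` stands in for any
- `ModelMixin` subclass, and the constructor arguments and path are illustrative):
- 
-     >>> model = UNet2DConditionModel(sample_size=32, in_channels=4, out_channels=4)
-     >>> model.save_pretrained("./my_unet")    # writes config.json plus the weights file
-     >>> reloaded = UNet2DConditionModel.from_pretrained("./my_unet")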
- """ - - config_name = CONFIG_NAME - _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _supports_gradient_checkpointing = False - - def __init__(self): - super().__init__() - - @property - def is_gradient_checkpointing(self) -> bool: - """ - Whether gradient checkpointing is activated for this model or not. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - return any( - hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing - for m in self.modules() - ) - - def enable_gradient_checkpointing(self): - """ - Activates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if not self._supports_gradient_checkpointing: - raise ValueError( - f"{self.__class__.__name__} does not support gradient checkpointing." - ) - self.apply(partial(self._set_gradient_checkpointing, value=True)) - - def disable_gradient_checkpointing(self): - """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if self._supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) - - def set_use_memory_efficient_attention_xformers( - self, valid: "bool", attention_op: "Optional[Callable]" = None - ) -> None: - def fn_recursive_set_mem_eff(module: "torch.nn.Module"): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid, attention_op) - for child in module.children(): - fn_recursive_set_mem_eff(child) - - for module in self.children(): - if isinstance(module, torch.nn.Module): - fn_recursive_set_mem_eff(module) - - def enable_xformers_memory_efficient_attention( - self, attention_op: "Optional[Callable]" = None - ): - self.set_use_memory_efficient_attention_xformers(True, attention_op) - - def disable_xformers_memory_efficient_attention(self): - """ - Disable memory efficient attention as implemented in xformers. - """ - self.set_use_memory_efficient_attention_xformers(False) - - def save_pretrained( - self, - save_directory: "Union[str, os.PathLike]", - is_main_process: "bool" = True, - save_function: "Callable" = None, - safe_serialization: "bool" = False, - variant: "Optional[str]" = None, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~models.ModelMixin.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
- variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. - """ - if safe_serialization and not is_safetensors_available(): - raise ImportError( - "`safe_serialization` requires the `safetensors library: `pip install safetensors`." - ) - if os.path.isfile(save_directory): - logger.error( - f"Provided path ({save_directory}) should be a directory, not a file" - ) - return - os.makedirs(save_directory, exist_ok=True) - model_to_save = self - if is_main_process: - model_to_save.save_config(save_directory) - state_dict = model_to_save.state_dict() - weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME - weights_name = _add_variant(weights_name, variant) - if safe_serialization: - safetensors.torch.save_file( - state_dict, - os.path.join(save_directory, weights_name), - metadata={"format": "pt"}, - ) - else: - torch.save(state_dict, os.path.join(save_directory, weights_name)) - logger.info( - f"Model weights saved in {os.path.join(save_directory, weights_name)}" - ) - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: "Optional[Union[str, os.PathLike]]", - **kwargs, - ): - """ - Instantiate a pretrained pytorch model from a pre-trained model configuration. - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 
- local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - from_flax (`bool`, *optional*, defaults to `False`): - Load the model weights from a Flax checkpoint save file. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. - variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. 
- - - - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - device_map = kwargs.pop("device_map", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - variant = kwargs.pop("variant", None) - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip install accelerate\n```\n." - ) - if device_map is not None and not is_accelerate_available(): - raise NotImplementedError( - "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set `device_map=None`. You can install accelerate with `pip install accelerate`." - ) - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set `device_map=None`." - ) - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set `low_cpu_mem_usage=False`." - ) - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and dispatching. Please make sure to set `low_cpu_mem_usage=True`." 
- ) - config_path = pretrained_model_name_or_path - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "pytorch", - } - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - user_agent=user_agent, - **kwargs, - ) - model_file = None - if from_flax: - model = cls.from_config(config, **unused_kwargs) - else: - if is_safetensors_available(): - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - except: - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - if low_cpu_mem_usage: - with accelerate.init_empty_weights(): - model = cls.from_config(config, **unused_kwargs) - if device_map is None: - param_device = "cpu" - state_dict = load_state_dict(model_file, variant=variant) - missing_keys = set(model.state_dict().keys()) - set( - state_dict.keys() - ) - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are missing: \n {', '.join(missing_keys)}. \n Please make sure to pass `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize those weights or else make sure your checkpoint file is correct." - ) - for param_name, param in state_dict.items(): - accepts_dtype = "dtype" in set( - inspect.signature( - set_module_tensor_to_device - ).parameters.keys() - ) - if accepts_dtype: - set_module_tensor_to_device( - model, - param_name, - param_device, - value=param, - dtype=torch_dtype, - ) - else: - set_module_tensor_to_device( - model, param_name, param_device, value=param - ) - else: - accelerate.load_checkpoint_and_dispatch( - model, model_file, device_map, dtype=torch_dtype - ) - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - else: - model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant) - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - error_msgs, - ) = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
- ) - elif torch_dtype is not None: - model = model - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - model.eval() - if output_loading_info: - return model, loading_info - return model - - @classmethod - def _load_pretrained_model( - cls, - model, - state_dict, - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - ): - model_state_dict = model.state_dict() - loaded_keys = [k for k in state_dict.keys()] - expected_keys = list(model_state_dict.keys()) - original_loaded_keys = loaded_keys - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - model_to_load = model - - def _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape - != model_state_dict[model_key].shape - ): - mismatched_keys.append( - ( - checkpoint_key, - state_dict[checkpoint_key].shape, - model_state_dict[model_key].shape, - ) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - if state_dict is not None: - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - raise RuntimeError( - f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}" - ) - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." - ) - else: - logger.info( - f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" - ) - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use {model.__class__.__name__} for predictions without further training." 
- ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} and are newly initialized because the shapes did not match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs - - @property - def device(self) -> device: - """ - `torch.device`: The device on which the module is (assuming that all the module parameters are on the same - device). - """ - return get_parameter_device(self) - - @property - def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - """ - return get_parameter_dtype(self) - - def num_parameters( - self, only_trainable: "bool" = False, exclude_embeddings: "bool" = False - ) -> int: - """ - Get number of (optionally, trainable or non-embeddings) parameters in the module. - - Args: - only_trainable (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of trainable parameters - - exclude_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of non-embeddings parameters - - Returns: - `int`: The number of parameters. - """ - if exclude_embeddings: - embedding_param_names = [ - f"{name}.weight" - for name, module_type in self.named_modules() - if isinstance(module_type, torch.nn.Embedding) - ] - non_embedding_parameters = [ - parameter - for name, parameter in self.named_parameters() - if name not in embedding_param_names - ] - return sum( - p.numel() - for p in non_embedding_parameters - if p.requires_grad or not only_trainable - ) - else: - return sum( - p.numel() - for p in self.parameters() - if p.requires_grad or not only_trainable - ) - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000 ** omega - pos = pos.reshape(-1) - out = np.einsum("m,d->md", pos, omega) - emb_sin = np.sin(out) - emb_cos = np.cos(out) - emb = np.concatenate([emb_sin, emb_cos], axis=1) - return emb - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) - emb = np.concatenate([emb_h, emb_w], axis=1) - return emb - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) - grid = np.stack(grid, axis=0) - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if 
cls_token and extra_tokens > 0: - pos_embed = np.concatenate( - [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 - ) - return pos_embed - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - ): - super().__init__() - num_patches = height // patch_size * (width // patch_size) - self.flatten = flatten - self.layer_norm = layer_norm - self.proj = nn.Conv2d( - in_channels, - embed_dim, - kernel_size=(patch_size, patch_size), - stride=patch_size, - bias=bias, - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-06) - else: - self.norm = None - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** 0.5)) - self.register_buffer( - "pos_embed", - torch.from_numpy(pos_embed).float().unsqueeze(0), - persistent=False, - ) - - def forward(self, latent): - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) - if self.layer_norm: - latent = self.norm(latent) - return latent + self.pos_embed - - -class BaseOutput(OrderedDict): - """ - Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a - tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular - python dictionary. - - - - You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple - before. - - - """ - - def __post_init__(self): - class_fields = fields(self) - if not len(class_fields): - raise ValueError(f"{self.__class__.__name__} has no fields.") - first_field = getattr(self, class_fields[0].name) - other_fields_are_none = all( - getattr(self, field.name) is None for field in class_fields[1:] - ) - if other_fields_are_none and isinstance(first_field, dict): - for key, value in first_field.items(): - self[key] = value - else: - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v - - def __delitem__(self, *args, **kwargs): - raise Exception( - f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." - ) - - def setdefault(self, *args, **kwargs): - raise Exception( - f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." - ) - - def pop(self, *args, **kwargs): - raise Exception( - f"You cannot use ``pop`` on a {self.__class__.__name__} instance." - ) - - def update(self, *args, **kwargs): - raise Exception( - f"You cannot use ``update`` on a {self.__class__.__name__} instance." - ) - - def __getitem__(self, k): - if isinstance(k, str): - inner_dict = {k: v for k, v in self.items()} - return inner_dict[k] - else: - return self.to_tuple()[k] - - def __setattr__(self, name, value): - if name in self.keys() and value is not None: - super().__setitem__(name, value) - super().__setattr__(name, value) - - def __setitem__(self, key, value): - super().__setitem__(key, value) - super().__setattr__(key, value) - - def to_tuple(self) -> Tuple[Any]: - """ - Convert self to a tuple containing all the attributes/keys that are not `None`. 
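# A minimal usage sketch of the sincos helpers and PatchEmbed defined above (illustration only,
# relying on the module-level torch/numpy imports): get_2d_sincos_pos_embed returns a
# (grid_size*grid_size, embed_dim) array that PatchEmbed registers as its `pos_embed` buffer.
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)            # numpy array, shape (196, 768)
assert pos.shape == (14 * 14, 768)
patchify = PatchEmbed(height=224, width=224, patch_size=16, in_channels=3, embed_dim=768)
tokens = patchify(torch.randn(1, 3, 224, 224))                        # (1, 196, 768): conv -> flatten -> + pos_embed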
- """ - return tuple(self[k] for k in self.keys()) - - -@dataclass -class Transformer2DModelOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions - for the unnoised latent pixels. - """ - - sample: "torch.FloatTensor" - - -def register_to_config(init): - """ - Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are - automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that - shouldn't be registered in the config, use the `ignore_for_config` class variable - - Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! - """ - - @functools.wraps(init) - def inner_init(self, *args, **kwargs): - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} - if not isinstance(self, ConfigMixin): - raise RuntimeError( - f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does not inherit from `ConfigMixin`." - ) - ignore = getattr(self, "ignore_for_config", []) - new_kwargs = {} - signature = inspect.signature(init) - parameters = { - name: p.default - for i, (name, p) in enumerate(signature.parameters.items()) - if i > 0 and name not in ignore - } - for arg, name in zip(args, parameters.keys()): - new_kwargs[name] = arg - new_kwargs.update( - { - k: init_kwargs.get(k, default) - for k, default in parameters.items() - if k not in ignore and k not in new_kwargs - } - ) - new_kwargs = {**config_init_kwargs, **new_kwargs} - getattr(self, "register_to_config")(**new_kwargs) - init(self, *args, **init_kwargs) - - return inner_init - - -class Transformer2DModel(ModelMixin, ConfigMixin): - """ - Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual - embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. - - Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised - image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. 
The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: "int" = 16, - attention_head_dim: "int" = 88, - in_channels: "Optional[int]" = None, - out_channels: "Optional[int]" = None, - num_layers: "int" = 1, - dropout: "float" = 0.0, - norm_num_groups: "int" = 32, - cross_attention_dim: "Optional[int]" = None, - attention_bias: "bool" = False, - sample_size: "Optional[int]" = None, - num_vector_embeds: "Optional[int]" = None, - patch_size: "Optional[int]" = None, - activation_fn: "str" = "geglu", - num_embeds_ada_norm: "Optional[int]" = None, - use_linear_projection: "bool" = False, - only_cross_attention: "bool" = False, - upcast_attention: "bool" = False, - norm_type: "str" = "layer_norm", - norm_elementwise_affine: "bool" = True, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.is_input_continuous = in_channels is not None and patch_size is None - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config. Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `transformer/config.json` file" - deprecate( - "norm_type!=num_embeds_ada_norm", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - norm_type = "ada_norm" - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make sure that either `num_vector_embeds` or `num_patches` is None." 
- ) - elif ( - not self.is_input_continuous - and not self.is_input_vectorized - and not self.is_input_patches - ): - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size: {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - if self.is_input_continuous: - self.in_channels = in_channels - self.norm = torch.nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=in_channels, - eps=1e-06, - affine=True, - ) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - elif self.is_input_vectorized: - assert ( - sample_size is not None - ), "Transformer2DModel over discrete input must provide sample_size" - assert ( - num_vector_embeds is not None - ), "Transformer2DModel over discrete input must provide num_embed" - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, - embed_dim=inner_dim, - height=self.height, - width=self.width, - ) - elif self.is_input_patches: - assert ( - sample_size is not None - ), "Transformer2DModel over patched input must provide sample_size" - self.height = sample_size - self.width = sample_size - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - ) - for d in range(num_layers) - ] - ) - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - if use_linear_projection: - self.proj_out = nn.Linear(inner_dim, in_channels) - else: - self.proj_out = nn.Conv2d( - inner_dim, in_channels, kernel_size=1, stride=1, padding=0 - ) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-06) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear( - inner_dim, patch_size * patch_size * self.out_channels - ) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - return_dict: "bool" = True, - ): - """ - Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. 
Used to indicate denoising step. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels - conditioning. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( - batch, height * width, inner_dim - ) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( - batch, height * width, inner_dim - ) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - logits = logits.permute(0, 2, 1) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = ( - self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - ) - hidden_states = self.proj_out_2(hidden_states) - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=( - -1, - height, - width, - self.patch_size, - self.patch_size, - self.out_channels, - ) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=( - -1, - self.out_channels, - height * self.patch_size, - width * self.patch_size, - ) - ) - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - - -class DualTransformer2DModel(nn.Module): - """ - Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. 
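# A minimal sketch of the continuous-input path of Transformer2DModel above (illustration only;
# it assumes the BasicTransformerBlock defined elsewhere in this file behaves like the diffusers
# original). inner_dim = num_attention_heads * attention_head_dim is the token width used internally.
xformer = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=4,          # inner_dim = 32
    in_channels=32,
    norm_num_groups=32,
    cross_attention_dim=32,
)
hidden = torch.randn(2, 32, 16, 16)                # (batch, channels, height, width)
context = torch.randn(2, 77, 32)                   # (batch, sequence, cross_attention_dim)
out = xformer(hidden, encoder_hidden_states=context).sample
assert out.shape == hidden.shape                   # the residual is added back, so shape is preserved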
The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - def __init__( - self, - num_attention_heads: "int" = 16, - attention_head_dim: "int" = 88, - in_channels: "Optional[int]" = None, - num_layers: "int" = 1, - dropout: "float" = 0.0, - norm_num_groups: "int" = 32, - cross_attention_dim: "Optional[int]" = None, - attention_bias: "bool" = False, - sample_size: "Optional[int]" = None, - num_vector_embeds: "Optional[int]" = None, - activation_fn: "str" = "geglu", - num_embeds_ada_norm: "Optional[int]" = None, - ): - super().__init__() - self.transformers = nn.ModuleList( - [ - Transformer2DModel( - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - in_channels=in_channels, - num_layers=num_layers, - dropout=dropout, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attention_bias=attention_bias, - sample_size=sample_size, - num_vector_embeds=num_vector_embeds, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - ) - for _ in range(2) - ] - ) - self.mix_ratio = 0.5 - self.condition_lengths = [77, 257] - self.transformer_index_for_condition = [1, 0] - - def forward( - self, - hidden_states, - encoder_hidden_states, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - return_dict: "bool" = True, - ): - """ - Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. - attention_mask (`torch.FloatTensor`, *optional*): - Optional attention mask to be applied in CrossAttention - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
- - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - input_states = hidden_states - encoded_states = [] - tokens_start = 0 - for i in range(2): - condition_state = encoder_hidden_states[ - :, tokens_start : tokens_start + self.condition_lengths[i] - ] - transformer_index = self.transformer_index_for_condition[i] - encoded_state = self.transformers[transformer_index]( - input_states, - encoder_hidden_states=condition_state, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] - encoded_states.append(encoded_state - input_states) - tokens_start += self.condition_lengths[i] - output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * ( - 1 - self.mix_ratio - ) - output_states = output_states + input_states - if not return_dict: - return (output_states,) - return Transformer2DModelOutput(sample=output_states) - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__( - self, - embedding_size: "int" = 256, - scale: "float" = 1.0, - set_W_to_weight=True, - log=True, - flip_sin_to_cos=False, - ): - super().__init__() - self.weight = nn.Parameter( - torch.randn(embedding_size) * scale, requires_grad=False - ) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - if set_W_to_weight: - self.W = nn.Parameter( - torch.randn(embedding_size) * scale, requires_grad=False - ) - self.weight = self.W - - def forward(self, x): - if self.log: - x = torch.log(x) - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - if self.flip_sin_to_cos: - out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) - else: - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return out - - -@dataclass -class PriorTransformerOutput(BaseOutput): - """ - Args: - predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - The predicted CLIP image embedding conditioned on the CLIP text embedding input. - """ - - predicted_image_embedding: "torch.FloatTensor" - - -class PriorTransformer(ModelMixin, ConfigMixin): - """ - The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the - transformer predicts the image embeddings through a denoising diffusion process. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - For more details, see the original paper: https://arxiv.org/abs/2204.06125 - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. - embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP - image embeddings and text embeddings are both the same dimension. - num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the - length of the prompt after it has been tokenized. - additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the - projected hidden_states. 
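# A minimal sketch of GaussianFourierProjection above (illustration only): with the default
# log=True the inputs must be positive noise levels; the output concatenates sin and cos
# features, giving 2 * embedding_size values per input.
fourier = GaussianFourierProjection(embedding_size=256)
sigmas = torch.tensor([0.1, 1.0, 10.0])
feats = fourier(sigmas)
assert feats.shape == (3, 512)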
The actual length of the used hidden_states is `num_embeddings + - additional_embeddings`. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - - """ - - @register_to_config - def __init__( - self, - num_attention_heads: "int" = 32, - attention_head_dim: "int" = 64, - num_layers: "int" = 20, - embedding_dim: "int" = 768, - num_embeddings=77, - additional_embeddings=4, - dropout: "float" = 0.0, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.additional_embeddings = additional_embeddings - self.time_proj = Timesteps(inner_dim, True, 0) - self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) - self.proj_in = nn.Linear(embedding_dim, inner_dim) - self.embedding_proj = nn.Linear(embedding_dim, inner_dim) - self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) - self.positional_embedding = nn.Parameter( - torch.zeros(1, num_embeddings + additional_embeddings, inner_dim) - ) - self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - activation_fn="gelu", - attention_bias=True, - ) - for d in range(num_layers) - ] - ) - self.norm_out = nn.LayerNorm(inner_dim) - self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) - causal_attention_mask = torch.full( - [ - num_embeddings + additional_embeddings, - num_embeddings + additional_embeddings, - ], - -10000.0, - ) - causal_attention_mask.triu_(1) - causal_attention_mask = causal_attention_mask[None, ...] - self.register_buffer( - "causal_attention_mask", causal_attention_mask, persistent=False - ) - self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) - - def forward( - self, - hidden_states, - timestep: "Union[torch.Tensor, float, int]", - proj_embedding: "torch.FloatTensor", - encoder_hidden_states: "torch.FloatTensor", - attention_mask: "Optional[torch.BoolTensor]" = None, - return_dict: "bool" = True, - ): - """ - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - x_t, the currently predicted image embeddings. - timestep (`torch.long`): - Current denoising step. - proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - Projected embedding vector the denoising process is conditioned on. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): - Hidden states of the text embeddings the denoising process is conditioned on. - attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): - Text mask for the text embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain - tuple. - - Returns: - [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: - [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
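# A minimal sketch of PriorTransformer above (illustration only; it assumes the Timesteps,
# TimestepEmbedding and BasicTransformerBlock helpers defined elsewhere in this file behave
# like the diffusers originals). embedding_dim is shared by x_t, the projected embedding and
# the text hidden states.
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=4,          # inner_dim = 8
    num_layers=2,
    embedding_dim=8,
    num_embeddings=7,
    additional_embeddings=4,
)
x_t = torch.randn(2, 8)
pred = prior(
    x_t,
    timestep=torch.tensor([1, 3]),
    proj_embedding=torch.randn(2, 8),
    encoder_hidden_states=torch.randn(2, 7, 8),
).predicted_image_embedding        # shape (2, 8)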
- """ - batch_size = hidden_states.shape[0] - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor( - [timesteps], dtype=torch.long, device=hidden_states.device - ) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None] - timesteps = timesteps * torch.ones( - batch_size, dtype=timesteps.dtype, device=timesteps.device - ) - timesteps_projected = self.time_proj(timesteps) - timesteps_projected = timesteps_projected - time_embeddings = self.time_embedding(timesteps_projected) - proj_embeddings = self.embedding_proj(proj_embedding) - encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) - hidden_states = self.proj_in(hidden_states) - prd_embedding = self.prd_embedding.expand(batch_size, -1, -1) - positional_embeddings = self.positional_embedding - hidden_states = torch.cat( - [ - encoder_hidden_states, - proj_embeddings[:, None, :], - time_embeddings[:, None, :], - hidden_states[:, None, :], - prd_embedding, - ], - dim=1, - ) - hidden_states = hidden_states + positional_embeddings - if attention_mask is not None: - attention_mask = (1 - attention_mask) * -10000.0 - attention_mask = F.pad( - attention_mask, (0, self.additional_embeddings), value=0.0 - ) - attention_mask = attention_mask[:, None, :] + self.causal_attention_mask - attention_mask = attention_mask.repeat_interleave( - self.config.num_attention_heads, dim=0 - ) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, attention_mask=attention_mask) - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states[:, -1] - predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) - if not return_dict: - return (predicted_image_embedding,) - return PriorTransformerOutput( - predicted_image_embedding=predicted_image_embedding - ) - - def post_process_latents(self, prior_latents): - prior_latents = prior_latents * self.clip_std + self.clip_mean - return prior_latents - - -class Upsample1D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(x) - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class Downsample1D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. 
- out_channels: - padding: - """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - if use_conv: - self.conv = nn.Conv1d( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.conv(x) - - -class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(hidden_states) - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - if output_size is None: - hidden_states = F.interpolate( - hidden_states, scale_factor=2.0, mode="nearest" - ) - else: - hidden_states = F.interpolate( - hidden_states, size=output_size, mode="nearest" - ) - if dtype == torch.bfloat16: - hidden_states = hidden_states - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.Conv2d_0(hidden_states) - return hidden_states - - -class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. 
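# A minimal shape sketch for the 1-D resampling layers above (illustration only): with
# use_conv=True, Upsample1D doubles and Downsample1D halves the temporal length.
x = torch.randn(1, 32, 64)                         # (batch, channels, length)
up = Upsample1D(32, use_conv=True)
down = Downsample1D(32, use_conv=True)
assert up(x).shape == (1, 32, 128)
assert down(up(x)).shape == (1, 32, 64)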
- out_channels: - padding: - """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - if use_conv: - conv = nn.Conv2d( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - if self.use_conv and self.padding == 0: - pad = 0, 1, 0, 1 - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - return hidden_states - - -def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - _, channel, in_h, in_w = tensor.shape - tensor = tensor.reshape(-1, in_h, in_w, 1) - _, in_h, in_w, minor = tensor.shape - kernel_h, kernel_w = kernel.shape - out = tensor.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) - out = out - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - ) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - return out.view(-1, channel, out_h, out_w) - - -class FirUpsample2D(nn.Module): - def __init__( - self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d( - channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `upsample_2d()` followed by `Conv2d()`. - - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. 
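# The same shape contract for the 2-D variants above (illustration only): Upsample2D doubles
# and Downsample2D halves the spatial resolution when use_conv=True.
h = torch.randn(1, 64, 32, 32)
up2d = Upsample2D(64, use_conv=True)
down2d = Downsample2D(64, use_conv=True)
assert up2d(h).shape == (1, 64, 64, 64)
assert down2d(up2d(h)).shape == (1, 64, 32, 32)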
- factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - kernel = kernel * (gain * factor ** 2) - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - pad_value = kernel.shape[0] - factor - (convW - 1) - stride = factor, factor - output_shape = (hidden_states.shape[2] - 1) * factor + convH, ( - hidden_states.shape[3] - 1 - ) * factor + convW - output_padding = ( - output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = hidden_states.shape[1] // inC - weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) - weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - inverse_conv = F.conv_transpose2d( - hidden_states, - weight, - stride=stride, - output_padding=output_padding, - padding=0, - ) - output = upfirdn2d_native( - inverse_conv, - torch.tensor(kernel, device=inverse_conv.device), - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), - ) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - def forward(self, hidden_states): - if self.use_conv: - height = self._upsample_2d( - hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel - ) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - return height - - -class FirDownsample2D(nn.Module): - def __init__( - self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d( - channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `Conv2d()` followed by `downsample_2d()`. - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and - same datatype as `x`. - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - kernel = kernel * gain - if self.use_conv: - _, _, convH, convW = weight.shape - pad_value = kernel.shape[0] - factor + (convW - 1) - stride_value = [factor, factor] - upfirdn_input = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - pad=((pad_value + 1) // 2, pad_value // 2), - ) - output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - return output - - def forward(self, hidden_states): - if self.use_conv: - downsample_input = self._downsample_2d( - hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel - ) - hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - hidden_states = self._downsample_2d( - hidden_states, kernel=self.fir_kernel, factor=2 - ) - return hidden_states - - -class KDownsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, x): - x = F.pad(x, (self.pad,) * 4, self.pad_mode) - weight = x.new_zeros( - [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] - ) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel - return F.conv2d(x, weight, stride=2) - - -class KUpsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, x): - x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = x.new_zeros( - [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] - ) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel - return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) - - -def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): - """Downsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
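# A minimal shape sketch for the "K" resampling blocks above (illustration only): both use the
# fixed separable [1, 3, 3, 1] / 8 kernel registered in __init__, so they have no learned parameters.
x = torch.randn(1, 3, 16, 16)
kd = KDownsample2D()
ku = KUpsample2D()
assert kd(x).shape == (1, 3, 8, 8)
assert ku(kd(x)).shape == (1, 3, 16, 16)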
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - kernel = kernel * gain - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, kernel, down=factor, pad=((pad_value + 1) // 2, pad_value // 2) - ) - return output - - -def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): - """Upsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is - a: multiple of the upsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - kernel = kernel * (gain * factor ** 2) - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel, - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - -class ResnetBlock2D(nn.Module): - """ - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv2d layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - groups_out (`int`, *optional*, default to None): - The number of groups to use for the second normalization layer. if set to None, same as `groups`. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. - time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. - By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or - "ada_group" for a stronger conditioning with scale and shift. - kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see - [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. - output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. - use_in_shortcut (`bool`, *optional*, default to `True`): - If `True`, add a 1x1 nn.conv2d layer for skip-connection. 
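# A minimal sketch of the functional FIR helpers above (illustration only): upsample_2d and
# downsample_2d wrap upfirdn2d_native with the padding rules derived from the kernel size.
x = torch.randn(1, 4, 16, 16)
y = upsample_2d(x, kernel=(1, 3, 3, 1))            # (1, 4, 32, 32)
z = downsample_2d(y, kernel=(1, 3, 3, 1))          # (1, 4, 16, 16)
assert y.shape == (1, 4, 32, 32) and z.shape == x.shape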
- up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. - down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. - conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the - `conv_shortcut` output. - conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. - If None, same as `out_channels`. - """ - - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-06, - non_linearity="swish", - time_embedding_norm="default", - kernel=None, - output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - self.time_embedding_norm = time_embedding_norm - if groups_out is None: - groups_out = groups - if self.time_embedding_norm == "ada_group": - self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) - else: - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels is not None: - if self.time_embedding_norm == "default": - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) - elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) - elif self.time_embedding_norm == "ada_group": - self.time_emb_proj = None - else: - raise ValueError( - f"unknown time_embedding_norm : {self.time_embedding_norm} " - ) - else: - self.time_emb_proj = None - if self.time_embedding_norm == "ada_group": - self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) - else: - self.norm2 = torch.nn.GroupNorm( - num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True - ) - self.dropout = torch.nn.Dropout(dropout) - conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = torch.nn.Conv2d( - out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1 - ) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - elif non_linearity == "gelu": - self.nonlinearity = nn.GELU() - self.upsample = self.downsample = None - if self.up: - if kernel == "fir": - fir_kernel = 1, 3, 3, 1 - self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") - else: - self.upsample = Upsample2D(in_channels, use_conv=False) - elif self.down: - if kernel == "fir": - fir_kernel = 1, 3, 3, 1 - self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) - else: - self.downsample = Downsample2D( - in_channels, use_conv=False, padding=1, name="op" - ) - self.use_in_shortcut = ( - self.in_channels != conv_2d_out_channels - if 
use_in_shortcut is None - else use_in_shortcut - ) - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, - conv_2d_out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=conv_shortcut_bias, - ) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - if self.time_embedding_norm == "ada_group": - hidden_states = self.norm1(hidden_states, temb) - else: - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - if self.upsample is not None: - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) - hidden_states = self.conv1(hidden_states) - if self.time_emb_proj is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - if self.time_embedding_norm == "ada_group": - hidden_states = self.norm2(hidden_states, temb) - else: - hidden_states = self.norm2(hidden_states) - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - return output_tensor - - -class Mish(torch.nn.Module): - def forward(self, hidden_states): - return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) - - -def rearrange_dims(tensor): - if len(tensor.shape) == 2: - return tensor[:, :, None] - if len(tensor.shape) == 3: - return tensor[:, :, None, :] - elif len(tensor.shape) == 4: - return tensor[:, :, 0, :] - else: - raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - self.conv1d = nn.Conv1d( - inp_channels, out_channels, kernel_size, padding=kernel_size // 2 - ) - self.group_norm = nn.GroupNorm(n_groups, out_channels) - self.mish = nn.Mish() - - def forward(self, x): - x = self.conv1d(x) - x = rearrange_dims(x) - x = self.group_norm(x) - x = rearrange_dims(x) - x = self.mish(x) - return x - - -class ResidualTemporalBlock1D(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): - super().__init__() - self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) - self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) - self.time_emb_act = nn.Mish() - self.time_emb = nn.Linear(embed_dim, out_channels) - self.residual_conv = ( - nn.Conv1d(inp_channels, out_channels, 1) - if inp_channels != out_channels - else nn.Identity() - ) - - def forward(self, x, t): - """ - Args: - x : [ batch_size x inp_channels x horizon ] - t : [ batch_size x embed_dim ] - - returns: - out : [ batch_size x out_channels x horizon ] - """ - t = self.time_emb_act(t) - t = self.time_emb(t) - out = self.conv_in(x) + rearrange_dims(t) - out = 
self.conv_out(out) - return out + self.residual_conv(x) - - -@dataclass -class UNet1DOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): - Hidden states output. Output of last layer of model. - """ - - sample: "torch.FloatTensor" - - -class LinearMultiDim(nn.Linear): - def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): - in_features = ( - [in_features, second_dim, 1] - if isinstance(in_features, int) - else list(in_features) - ) - if out_features is None: - out_features = in_features - out_features = ( - [out_features, second_dim, 1] - if isinstance(out_features, int) - else list(out_features) - ) - self.in_features_multidim = in_features - self.out_features_multidim = out_features - super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) - - def forward(self, input_tensor, *args, **kwargs): - shape = input_tensor.shape - n_dim = len(self.in_features_multidim) - input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) - output_tensor = super().forward(input_tensor) - output_tensor = output_tensor.view( - *shape[0:-n_dim], *self.out_features_multidim - ) - return output_tensor - - -class ResnetBlockFlat(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-06, - time_embedding_norm="default", - use_in_shortcut=None, - second_dim=4, - **kwargs, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - in_channels = ( - [in_channels, second_dim, 1] - if isinstance(in_channels, int) - else list(in_channels) - ) - self.in_channels_prod = np.array(in_channels).prod() - self.channels_multidim = in_channels - if out_channels is not None: - out_channels = ( - [out_channels, second_dim, 1] - if isinstance(out_channels, int) - else list(out_channels) - ) - out_channels_prod = np.array(out_channels).prod() - self.out_channels_multidim = out_channels - else: - out_channels_prod = self.in_channels_prod - self.out_channels_multidim = self.channels_multidim - self.time_embedding_norm = time_embedding_norm - if groups_out is None: - groups_out = groups - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True - ) - self.conv1 = torch.nn.Conv2d( - self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0 - ) - if temb_channels is not None: - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) - else: - self.time_emb_proj = None - self.norm2 = torch.nn.GroupNorm( - num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True - ) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels_prod, out_channels_prod, kernel_size=1, padding=0 - ) - self.nonlinearity = nn.SiLU() - self.use_in_shortcut = ( - self.in_channels_prod != out_channels_prod - if use_in_shortcut is None - else use_in_shortcut - ) - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - self.in_channels_prod, - out_channels_prod, - kernel_size=1, - stride=1, - padding=0, - ) - - def forward(self, input_tensor, temb): - shape = input_tensor.shape - n_dim = len(self.channels_multidim) - input_tensor = input_tensor.reshape( - *shape[0:-n_dim], self.in_channels_prod, 1, 1 - ) - input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) - hidden_states = input_tensor - hidden_states = 
self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - hidden_states = hidden_states + temb - hidden_states = self.norm2(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states - output_tensor = output_tensor.view(*shape[0:-n_dim], -1) - output_tensor = output_tensor.view( - *shape[0:-n_dim], *self.out_channels_multidim - ) - return output_tensor - - -class CrossAttnDownBlockFlat(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - LinearMultiDim( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - output_states = () - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - 
hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) - return hidden_states, output_states - - -class DownBlockFlat(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - LinearMultiDim( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) - return hidden_states, output_states - - -# def get_down_block( -# down_block_type, -# num_layers, -# in_channels, -# out_channels, -# temb_channels, -# add_downsample, -# resnet_eps, -# resnet_act_fn, -# attn_num_head_channels, -# resnet_groups=None, -# cross_attention_dim=None, -# downsample_padding=None, -# dual_cross_attention=False, -# use_linear_projection=False, -# only_cross_attention=False, -# upcast_attention=False, -# resnet_time_scale_shift="default", -# ): -# down_block_type = ( -# down_block_type[7:] -# if down_block_type.startswith("UNetRes") -# else down_block_type -# ) -# if down_block_type == "DownBlockFlat": -# return DownBlockFlat( -# num_layers=num_layers, -# in_channels=in_channels, -# out_channels=out_channels, -# temb_channels=temb_channels, -# add_downsample=add_downsample, -# resnet_eps=resnet_eps, -# resnet_act_fn=resnet_act_fn, -# resnet_groups=resnet_groups, -# downsample_padding=downsample_padding, -# resnet_time_scale_shift=resnet_time_scale_shift, -# ) -# elif down_block_type == "CrossAttnDownBlockFlat": -# if cross_attention_dim is None: -# raise 
ValueError( -# "cross_attention_dim must be specified for CrossAttnDownBlockFlat" -# ) -# return CrossAttnDownBlockFlat( -# num_layers=num_layers, -# in_channels=in_channels, -# out_channels=out_channels, -# temb_channels=temb_channels, -# add_downsample=add_downsample, -# resnet_eps=resnet_eps, -# resnet_act_fn=resnet_act_fn, -# resnet_groups=resnet_groups, -# downsample_padding=downsample_padding, -# cross_attention_dim=cross_attention_dim, -# attn_num_head_channels=attn_num_head_channels, -# dual_cross_attention=dual_cross_attention, -# use_linear_projection=use_linear_projection, -# only_cross_attention=only_cross_attention, -# resnet_time_scale_shift=resnet_time_scale_shift, -# ) -# raise ValueError(f"{down_block_type} is not supported.") - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - down_block_type = ( - down_block_type[7:] - if down_block_type.startswith("UNetRes") - else down_block_type - ) - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "ResnetDownsampleBlock2D": - return ResnetDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownBlock2D": - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnDownBlock2D" - ) - return CrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SimpleCrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" - ) - return SimpleCrossAttnDownBlock2D( - num_layers=num_layers, - 
in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownEncoderBlock2D": - return AttnDownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "KDownBlock2D": - return KDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif down_block_type == "KCrossAttnDownBlock2D": - return KCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - add_self_attention=True if not add_downsample else False, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -class MidResTemporalBlock1D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - embed_dim, - num_layers: "int" = 1, - add_downsample: "bool" = False, - add_upsample: "bool" = False, - non_linearity=None, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.add_downsample = add_downsample - resnets = [ - ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim) - ] - for _ in range(num_layers): - resnets.append( - ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim) - ) - self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: - 
self.nonlinearity = None - self.upsample = None - if add_upsample: - self.upsample = Upsample1D(out_channels, use_conv=True) - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True) - if self.upsample and self.downsample: - raise ValueError("Block cannot downsample and upsample") - - def forward(self, hidden_states, temb): - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - if self.upsample: - hidden_states = self.upsample(hidden_states) - if self.downsample: - hidden_states = self.downsample(hidden_states) - return hidden_states - - -_kernels = { - "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], - "cubic": [ - -0.01171875, - -0.03515625, - 0.11328125, - 0.43359375, - 0.43359375, - 0.11328125, - -0.03515625, - -0.01171875, - ], - "lanczos3": [ - 0.003689131001010537, - 0.015056144446134567, - -0.03399861603975296, - -0.066637322306633, - 0.13550527393817902, - 0.44638532400131226, - 0.44638532400131226, - 0.13550527393817902, - -0.066637322306633, - -0.03399861603975296, - 0.015056144446134567, - 0.003689131001010537, - ], -} - - -class Downsample1d(nn.Module): - def __init__(self, kernel="linear", pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer("kernel", kernel_1d) - - def forward(self, hidden_states): - hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) - weight = hidden_states.new_zeros( - [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]] - ) - indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel - return F.conv1d(hidden_states, weight, stride=2) - - -class ResConvBlock(nn.Module): - def __init__(self, in_channels, mid_channels, out_channels, is_last=False): - super().__init__() - self.is_last = is_last - self.has_conv_skip = in_channels != out_channels - if self.has_conv_skip: - self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) - self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) - self.group_norm_1 = nn.GroupNorm(1, mid_channels) - self.gelu_1 = nn.GELU() - self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) - if not self.is_last: - self.group_norm_2 = nn.GroupNorm(1, out_channels) - self.gelu_2 = nn.GELU() - - def forward(self, hidden_states): - residual = ( - self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states - ) - hidden_states = self.conv_1(hidden_states) - hidden_states = self.group_norm_1(hidden_states) - hidden_states = self.gelu_1(hidden_states) - hidden_states = self.conv_2(hidden_states) - if not self.is_last: - hidden_states = self.group_norm_2(hidden_states) - hidden_states = self.gelu_2(hidden_states) - output = hidden_states + residual - return output - - -class SelfAttention1d(nn.Module): - def __init__(self, in_channels, n_head=1, dropout_rate=0.0): - super().__init__() - self.channels = in_channels - self.group_norm = nn.GroupNorm(1, num_channels=in_channels) - self.num_heads = n_head - self.query = nn.Linear(self.channels, self.channels) - self.key = nn.Linear(self.channels, self.channels) - self.value = nn.Linear(self.channels, self.channels) - self.proj_attn = nn.Linear(self.channels, self.channels, 1) - self.dropout = nn.Dropout(dropout_rate, inplace=True) - - def transpose_for_scores(self, projection: "torch.Tensor") -> torch.Tensor: -
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states): - residual = hidden_states - batch, channel_dim, seq = hidden_states.shape - hidden_states = self.group_norm(hidden_states) - hidden_states = hidden_states.transpose(1, 2) - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) - attention_scores = torch.matmul( - query_states * scale, key_states.transpose(-1, -2) * scale - ) - attention_probs = torch.softmax(attention_scores, dim=-1) - hidden_states = torch.matmul(attention_probs, value_states) - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(1, 2) - hidden_states = self.dropout(hidden_states) - output = hidden_states + residual - return output - - -class Upsample1d(nn.Module): - def __init__(self, kernel="linear", pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) * 2 - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer("kernel", kernel_1d) - - def forward(self, hidden_states, temb=None): - hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) - weight = hidden_states.new_zeros( - [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]] - ) - indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel - return F.conv_transpose1d( - hidden_states, weight, stride=2, padding=self.pad * 2 + 1 - ) - - -class UNetMidBlock1D(nn.Module): - def __init__(self, mid_channels, in_channels, out_channels=None): - super().__init__() - out_channels = in_channels if out_channels is None else out_channels - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - self.up = Upsample1d(kernel="cubic") - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.down(hidden_states) - for attn, resnet in zip(self.attentions, self.resnets): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - hidden_states = self.up(hidden_states) - return hidden_states - - -class ValueFunctionMidBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, embed_dim): - super().__init__() - 
self.in_channels = in_channels - self.out_channels = out_channels - self.embed_dim = embed_dim - self.res1 = ResidualTemporalBlock1D( - in_channels, in_channels // 2, embed_dim=embed_dim - ) - self.down1 = Downsample1D(out_channels // 2, use_conv=True) - self.res2 = ResidualTemporalBlock1D( - in_channels // 2, in_channels // 4, embed_dim=embed_dim - ) - self.down2 = Downsample1D(out_channels // 4, use_conv=True) - - def forward(self, x, temb=None): - x = self.res1(x, temb) - x = self.down1(x) - x = self.res2(x, temb) - x = self.down2(x) - return x - - -def get_mid_block( - mid_block_type, - num_layers, - in_channels, - mid_channels, - out_channels, - embed_dim, - add_downsample, -): - if mid_block_type == "MidResTemporalBlock1D": - return MidResTemporalBlock1D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - embed_dim=embed_dim, - add_downsample=add_downsample, - ) - elif mid_block_type == "ValueFunctionMidBlock1D": - return ValueFunctionMidBlock1D( - in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim - ) - elif mid_block_type == "UNetMidBlock1D": - return UNetMidBlock1D( - in_channels=in_channels, - mid_channels=mid_channels, - out_channels=out_channels, - ) - raise ValueError(f"{mid_block_type} does not exist.") - - -class OutConv1DBlock(nn.Module): - def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): - super().__init__() - self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) - if act_fn == "silu": - self.final_conv1d_act = nn.SiLU() - if act_fn == "mish": - self.final_conv1d_act = nn.Mish() - self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) - - def forward(self, hidden_states, temb=None): - hidden_states = self.final_conv1d_1(hidden_states) - hidden_states = rearrange_dims(hidden_states) - hidden_states = self.final_conv1d_gn(hidden_states) - hidden_states = rearrange_dims(hidden_states) - hidden_states = self.final_conv1d_act(hidden_states) - hidden_states = self.final_conv1d_2(hidden_states) - return hidden_states - - -class OutValueFunctionBlock(nn.Module): - def __init__(self, fc_dim, embed_dim): - super().__init__() - self.final_block = nn.ModuleList( - [ - nn.Linear(fc_dim + embed_dim, fc_dim // 2), - nn.Mish(), - nn.Linear(fc_dim // 2, 1), - ] - ) - - def forward(self, hidden_states, temb): - hidden_states = hidden_states.view(hidden_states.shape[0], -1) - hidden_states = torch.cat((hidden_states, temb), dim=-1) - for layer in self.final_block: - hidden_states = layer(hidden_states) - return hidden_states - - -def get_out_block( - *, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim -): - if out_block_type == "OutConv1DBlock": - return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) - elif out_block_type == "ValueFunction": - return OutValueFunctionBlock(fc_dim, embed_dim) - return None - - -class CrossAttnUpBlockFlat(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - prev_output_channel: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - 
upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlockFlat( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - for resnet, attn in zip(self.resnets, self.attentions): - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states - - -class UpBlockFlat(nn.Module): - def __init__( - self, - in_channels: "int", - prev_output_channel: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - res_skip_channels = in_channels if i == 
num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlockFlat( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states - - -# def get_up_block( -# up_block_type, -# num_layers, -# in_channels, -# out_channels, -# prev_output_channel, -# temb_channels, -# add_upsample, -# resnet_eps, -# resnet_act_fn, -# attn_num_head_channels, -# resnet_groups=None, -# cross_attention_dim=None, -# dual_cross_attention=False, -# use_linear_projection=False, -# only_cross_attention=False, -# upcast_attention=False, -# resnet_time_scale_shift="default", -# ): -# up_block_type = ( -# up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type -# ) -# if up_block_type == "UpBlockFlat": -# return UpBlockFlat( -# num_layers=num_layers, -# in_channels=in_channels, -# out_channels=out_channels, -# prev_output_channel=prev_output_channel, -# temb_channels=temb_channels, -# add_upsample=add_upsample, -# resnet_eps=resnet_eps, -# resnet_act_fn=resnet_act_fn, -# resnet_groups=resnet_groups, -# resnet_time_scale_shift=resnet_time_scale_shift, -# ) -# elif up_block_type == "CrossAttnUpBlockFlat": -# if cross_attention_dim is None: -# raise ValueError( -# "cross_attention_dim must be specified for CrossAttnUpBlockFlat" -# ) -# return CrossAttnUpBlockFlat( -# num_layers=num_layers, -# in_channels=in_channels, -# out_channels=out_channels, -# prev_output_channel=prev_output_channel, -# temb_channels=temb_channels, -# add_upsample=add_upsample, -# resnet_eps=resnet_eps, -# resnet_act_fn=resnet_act_fn, -# resnet_groups=resnet_groups, -# cross_attention_dim=cross_attention_dim, -# attn_num_head_channels=attn_num_head_channels, -# dual_cross_attention=dual_cross_attention, -# use_linear_projection=use_linear_projection, -# only_cross_attention=only_cross_attention, -# resnet_time_scale_shift=resnet_time_scale_shift, -# ) -# raise ValueError(f"{up_block_type} is not supported.") -# - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - 
dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - up_block_type = ( - up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - ) - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "ResnetUpsampleBlock2D": - return ResnetUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnUpBlock2D" - ) - return CrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SimpleCrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" - ) - return SimpleCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpBlock2D": - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - 
resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpDecoderBlock2D": - return AttnUpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "KUpBlock2D": - return KUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "KCrossAttnUpBlock2D": - return KCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - - raise ValueError(f"{up_block_type} does not exist.") - - -class UNet1DModel(ModelMixin, ConfigMixin): - """ - UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns a sample-shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all models (such as downloading or saving, etc.). - - Parameters: - sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. - in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. - time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sin to cos for Fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): Tuple of block output channels. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for the middle of the UNet. - out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of the UNet. - act_fn (`str`, *optional*, defaults to `None`): Optional activation function used in UNet blocks. - norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for group normalization in UNet blocks. - layers_per_block (`int`, *optional*, defaults to 1): The number of layers added per UNet block. - downsample_each_block (`bool`, *optional*, defaults to `False`): - Experimental feature for using a UNet without upsampling.
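Before the constructor that follows, it can help to see how the down and up paths pair their channels. A small standalone trace of that bookkeeping, using the default values listed in the docstring above (pure Python, no model is instantiated; `extra_in_channels=0` and `out_block_type=None` are assumed):

# Trace of the channel bookkeeping performed by the UNet1DModel constructor below.
in_channels = 2
out_channels = 2
extra_in_channels = 0
block_out_channels = (32, 32, 64)
down_block_types = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")
up_block_types = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")

output_channel = in_channels
for i, name in enumerate(down_block_types):
    input_channel = output_channel
    output_channel = block_out_channels[i]
    if i == 0:
        input_channel += extra_in_channels  # extra channels only feed the first block
    print("down:", name, input_channel, "->", output_channel)

reversed_out = list(reversed(block_out_channels))
output_channel = reversed_out[0]
final_upsample_channels = out_channels      # because out_block_type is None
for i, name in enumerate(up_block_types):
    prev_output_channel = output_channel
    output_channel = (
        reversed_out[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
    )
    print("up:  ", name, prev_output_channel, "->", output_channel)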
- """ - - @register_to_config - def __init__( - self, - sample_size: "int" = 65536, - sample_rate: "Optional[int]" = None, - in_channels: "int" = 2, - out_channels: "int" = 2, - extra_in_channels: "int" = 0, - time_embedding_type: "str" = "fourier", - flip_sin_to_cos: "bool" = True, - use_timestep_embedding: "bool" = False, - freq_shift: "float" = 0.0, - down_block_types: "Tuple[str]" = ( - "DownBlock1DNoSkip", - "DownBlock1D", - "AttnDownBlock1D", - ), - up_block_types: "Tuple[str]" = ( - "AttnUpBlock1D", - "UpBlock1D", - "UpBlock1DNoSkip", - ), - mid_block_type: "Tuple[str]" = "UNetMidBlock1D", - out_block_type: "str" = None, - block_out_channels: "Tuple[int]" = (32, 32, 64), - act_fn: "str" = None, - norm_num_groups: "int" = 8, - layers_per_block: "int" = 1, - downsample_each_block: "bool" = False, - ): - super().__init__() - self.sample_size = sample_size - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection( - embedding_size=8, - set_W_to_weight=False, - log=False, - flip_sin_to_cos=flip_sin_to_cos, - ) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps( - block_out_channels[0], - flip_sin_to_cos=flip_sin_to_cos, - downscale_freq_shift=freq_shift, - ) - timestep_input_dim = block_out_channels[0] - if use_timestep_embedding: - time_embed_dim = block_out_channels[0] * 4 - self.time_mlp = TimestepEmbedding( - in_channels=timestep_input_dim, - time_embed_dim=time_embed_dim, - act_fn=act_fn, - out_dim=block_out_channels[0], - ) - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - self.out_block = None - output_channel = in_channels - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - if i == 0: - input_channel += extra_in_channels - is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_downsample=not is_final_block or downsample_each_block, - ) - self.down_blocks.append(down_block) - self.mid_block = get_mid_block( - mid_block_type, - in_channels=block_out_channels[-1], - mid_channels=block_out_channels[-1], - out_channels=block_out_channels[-1], - embed_dim=block_out_channels[0], - num_layers=layers_per_block, - add_downsample=downsample_each_block, - ) - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - if out_block_type is None: - final_upsample_channels = out_channels - else: - final_upsample_channels = block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = ( - reversed_block_out_channels[i + 1] - if i < len(up_block_types) - 1 - else final_upsample_channels - ) - is_final_block = i == len(block_out_channels) - 1 - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block, - in_channels=prev_output_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_upsample=not is_final_block, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - num_groups_out = ( - norm_num_groups - if norm_num_groups is not None - else min(block_out_channels[0] // 4, 32) - ) - self.out_block = get_out_block( - out_block_type=out_block_type, - num_groups_out=num_groups_out, - 
embed_dim=block_out_channels[0], - out_channels=out_channels, - act_fn=act_fn, - fc_dim=block_out_channels[-1] // 4, - ) - - def forward( - self, - sample: "torch.FloatTensor", - timestep: "Union[torch.Tensor, float, int]", - return_dict: "bool" = True, - ) -> Union[UNet1DOutput, Tuple]: - """ - Args: - sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor( - [timesteps], dtype=torch.long, device=sample.device - ) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None] - timestep_embed = self.time_proj(timesteps) - if self.config.use_timestep_embedding: - timestep_embed = self.time_mlp(timestep_embed) - else: - timestep_embed = timestep_embed[..., None] - timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]) - timestep_embed = timestep_embed.broadcast_to( - sample.shape[:1] + timestep_embed.shape[1:] - ) - down_block_res_samples = () - for downsample_block in self.down_blocks: - sample, res_samples = downsample_block( - hidden_states=sample, temb=timestep_embed - ) - down_block_res_samples += res_samples - if self.mid_block: - sample = self.mid_block(sample, timestep_embed) - for i, upsample_block in enumerate(self.up_blocks): - res_samples = down_block_res_samples[-1:] - down_block_res_samples = down_block_res_samples[:-1] - sample = upsample_block( - sample, res_hidden_states_tuple=res_samples, temb=timestep_embed - ) - if self.out_block: - sample = self.out_block(sample, timestep_embed) - if not return_dict: - return (sample,) - return UNet1DOutput(sample=sample) - - -class DownResnetBlock1D(nn.Module): - def __init__( - self, - in_channels, - out_channels=None, - num_layers=1, - conv_shortcut=False, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.add_downsample = add_downsample - self.output_scale_factor = output_scale_factor - if groups_out is None: - groups_out = groups - resnets = [ - ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels) - ] - for _ in range(num_layers): - resnets.append( - ResidualTemporalBlock1D( - out_channels, out_channels, embed_dim=temb_channels - ) - ) - self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: - self.nonlinearity = None - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) - - def forward(self, hidden_states, temb=None): - output_states = () - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in 
self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - if self.nonlinearity is not None: - hidden_states = self.nonlinearity(hidden_states) - if self.downsample is not None: - hidden_states = self.downsample(hidden_states) - return hidden_states, output_states - - -class UpResnetBlock1D(nn.Module): - def __init__( - self, - in_channels, - out_channels=None, - num_layers=1, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.time_embedding_norm = time_embedding_norm - self.add_upsample = add_upsample - self.output_scale_factor = output_scale_factor - if groups_out is None: - groups_out = groups - resnets = [ - ResidualTemporalBlock1D( - 2 * in_channels, out_channels, embed_dim=temb_channels - ) - ] - for _ in range(num_layers): - resnets.append( - ResidualTemporalBlock1D( - out_channels, out_channels, embed_dim=temb_channels - ) - ) - self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: - self.nonlinearity = None - self.upsample = None - if add_upsample: - self.upsample = Upsample1D(out_channels, use_conv_transpose=True) - - def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): - if res_hidden_states_tuple is not None: - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - if self.nonlinearity is not None: - hidden_states = self.nonlinearity(hidden_states) - if self.upsample is not None: - hidden_states = self.upsample(hidden_states) - return hidden_states - - -class AttnDownBlock1D(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.down(hidden_states) - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - return hidden_states, (hidden_states,) - - -class DownBlock1D(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - self.resnets = nn.ModuleList(resnets) - - def 
forward(self, hidden_states, temb=None): - hidden_states = self.down(hidden_states) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - return hidden_states, (hidden_states,) - - -class DownBlock1DNoSkip(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = torch.cat([hidden_states, temb], dim=1) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - return hidden_states, (hidden_states,) - - -class AttnUpBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.up = Upsample1d(kernel="cubic") - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - hidden_states = self.up(hidden_states) - return hidden_states - - -class UpBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - mid_channels = in_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - self.resnets = nn.ModuleList(resnets) - self.up = Upsample1d(kernel="cubic") - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - hidden_states = self.up(hidden_states) - return hidden_states - - -class UpBlock1DNoSkip(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - mid_channels = in_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), - ] - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - return hidden_states - - -@dataclass -class UNet2DOutput(BaseOutput): - """ - Args: - sample 
(`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states output. Output of last layer of model. - """ - - sample: "torch.FloatTensor" - - -class UNetMidBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - add_attention: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.add_attention = add_attention - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - for _ in range(num_layers): - if self.add_attention: - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - else: - attentions.append(None) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, temb) - return hidden_states - - -class UNet2DModel(ModelMixin, ConfigMixin): - """ - UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. - out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`True`): Whether to flip sin to cos for fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block - types. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): - The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. 
- up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(224, 448, 672, 896)`): Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. - mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. - downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. - norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. - norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, or `"identity"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - """ - - @register_to_config - def __init__( - self, - sample_size: "Optional[Union[int, Tuple[int, int]]]" = None, - in_channels: "int" = 3, - out_channels: "int" = 3, - center_input_sample: "bool" = False, - time_embedding_type: "str" = "positional", - freq_shift: "int" = 0, - flip_sin_to_cos: "bool" = True, - down_block_types: "Tuple[str]" = ( - "DownBlock2D", - "AttnDownBlock2D", - "AttnDownBlock2D", - "AttnDownBlock2D", - ), - up_block_types: "Tuple[str]" = ( - "AttnUpBlock2D", - "AttnUpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - ), - block_out_channels: "Tuple[int]" = (224, 448, 672, 896), - layers_per_block: "int" = 2, - mid_block_scale_factor: "float" = 1, - downsample_padding: "int" = 1, - act_fn: "str" = "silu", - attention_head_dim: "Optional[int]" = 8, - norm_num_groups: "int" = 32, - norm_eps: "float" = 1e-05, - resnet_time_scale_shift: "str" = "default", - add_attention: "bool" = True, - class_embed_type: "Optional[str]" = None, - num_class_embeds: "Optional[int]" = None, - ): - super().__init__() - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
- ) - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) - ) - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection( - embedding_size=block_out_channels[0], scale=16 - ) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos, freq_shift - ) - timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - else: - self.class_embedding = None - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=attention_head_dim, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - attn_num_head_channels=attention_head_dim, - resnet_groups=norm_num_groups, - add_attention=add_attention, - ) - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] - is_final_block = i == len(block_out_channels) - 1 - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - num_groups_out = ( - norm_num_groups - if norm_num_groups is not None - else min(block_out_channels[0] // 4, 32) - ) - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps - ) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=3, padding=1 - ) - - def forward( - self, - sample: "torch.FloatTensor", - timestep: "Union[torch.Tensor, float, int]", - class_labels: "Optional[torch.Tensor]" = None, - return_dict: "bool" = True, - ) 
-> Union[UNet2DOutput, Tuple]: - """ - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor( - [timesteps], dtype=torch.long, device=sample.device - ) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None] - timesteps = timesteps * torch.ones( - sample.shape[0], dtype=timesteps.dtype, device=timesteps.device - ) - t_emb = self.time_proj(timesteps) - t_emb = t_emb - emb = self.time_embedding(t_emb) - if self.class_embedding is not None: - if class_labels is None: - raise ValueError( - "class_labels should be provided when doing class conditioning" - ) - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels) - emb = emb + class_emb - skip_sample = sample - sample = self.conv_in(sample) - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "skip_conv"): - sample, res_samples, skip_sample = downsample_block( - hidden_states=sample, temb=emb, skip_sample=skip_sample - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - down_block_res_samples += res_samples - sample = self.mid_block(sample, emb) - skip_sample = None - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - if hasattr(upsample_block, "skip_conv"): - sample, skip_sample = upsample_block( - sample, res_samples, emb, skip_sample - ) - else: - sample = upsample_block(sample, res_samples, emb) - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - if skip_sample is not None: - sample += skip_sample - if self.config.time_embedding_type == "fourier": - timesteps = timesteps.reshape( - (sample.shape[0], *([1] * len(sample.shape[1:]))) - ) - sample = sample / timesteps - if not return_dict: - return (sample,) - return UNet2DOutput(sample=sample) - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) 
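# Structure note: the lists built below always end up with num_layers + 1 resnets and
# num_layers attention modules; forward() runs resnets[0] first and then alternates
# attention/resnet pairs. Illustrative call order for num_layers=1 (sketch only,
# placeholder names h/ehs/kw):
#   h = resnets[0](h, temb)
#   h = attentions[0](h, encoder_hidden_states=ehs, cross_attention_kwargs=kw).sample
#   h = resnets[1](h, temb)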
- resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = resnet(hidden_states, temb) - return hidden_states - - -class UNetMidBlock2DSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - ): - super().__init__() - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.num_heads = in_channels // self.attn_num_head_channels - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - for _ in range(num_layers): - attentions.append( - CrossAttention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, 
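# output_scale_factor rescales the residual sum inside ResnetBlock2D; 1.0 leaves it
# unchanged, while the skip-style blocks further down pass np.sqrt(2.0).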
- output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - hidden_states = resnet(hidden_states, temb) - return hidden_states - - -class AttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) - return hidden_states, output_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - 
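# Only the first iteration (i == 0) sees the block's true input width; the loop header
# above reassigns in_channels to out_channels afterwards. E.g. with in_channels=320,
# out_channels=640, num_layers=2 the resnet widths are 320->640, then 640->640
# (illustrative values only).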
in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - output_states = () - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - 
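# Downsample2D with use_conv=True is a stride-2 convolution (registered under the
# name "op"), so each downsampler halves the spatial resolution.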
out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) - return hidden_states, output_states - - -class DownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - return hidden_states - - -class AttnDownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - attentions = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - 
out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - return hidden_states - - -class AttnSkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)] - ) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1) - ) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - hidden_states = self.skip_conv(skip_sample) + hidden_states - output_states += (hidden_states,) - return hidden_states, output_states, skip_sample - - -class SkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_pre_norm: "bool" = True, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - 
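# As in AttnSkipDownBlock2D above, group counts are derived from the channel widths
# (min(channels // 4, 32)) rather than from a resnet_groups argument, and the skip path
# is handled by FirDownsample2D plus a 3 -> out_channels skip_conv when add_downsample
# is set.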
ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)] - ) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1) - ) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - hidden_states = self.skip_conv(skip_sample) + hidden_states - output_states += (hidden_states,) - return hidden_states, output_states, skip_sample - - -class ResnetDownsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - 
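# Here each downsampler is itself a ResnetBlock2D built with down=True, so it is
# called with temb (unlike the plain-conv Downsample2D used by the other down blocks).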
hidden_states = downsampler(hidden_states, temb) - output_states += (hidden_states,) - return hidden_states, output_states - - -class SimpleCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - self.has_cross_attention = True - resnets = [] - attentions = [] - self.attn_num_head_channels = attn_num_head_channels - self.num_heads = out_channels // self.attn_num_head_channels - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - output_states = () - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - output_states += (hidden_states,) - return hidden_states, output_states - - -class KDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 4, - resnet_eps: "float" = 1e-05, - resnet_act_fn: "str" = "gelu", - resnet_group_size: "int" = 32, - add_downsample=False, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - 
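# groups/groups_out are computed from resnet_group_size (channels // resnet_group_size);
# together with the "ada_group" time-embedding norm and bias-free conv shortcut below,
# this is what distinguishes the K-style blocks from the default ResnetBlock2D setup.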
groups=groups, - groups_out=groups_out, - eps=resnet_eps, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_downsample: - self.downsamplers = nn.ModuleList([KDownsample2D()]) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - return hidden_states, output_states - - -class KAttentionBlock(nn.Module): - """ - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
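        add_self_attention (`bool`, *optional*, defaults to `False`): Whether to prepend a self-attention layer,
            normalized with `AdaGroupNorm` conditioned on `temb_channels`, before the cross-attention.
        group_size (`int`, *optional*, defaults to 32): Divisor used to derive the `AdaGroupNorm` group count
            (`max(1, dim // group_size)`).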
- """ - - def __init__( - self, - dim: "int", - num_attention_heads: "int", - attention_head_dim: "int", - dropout: "float" = 0.0, - cross_attention_dim: "Optional[int]" = None, - attention_bias: "bool" = False, - upcast_attention: "bool" = False, - temb_channels: "int" = 768, - add_self_attention: "bool" = False, - cross_attention_norm: "bool" = False, - group_size: "int" = 32, - ): - super().__init__() - self.add_self_attention = add_self_attention - if add_self_attention: - self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=None, - cross_attention_norm=None, - ) - self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - cross_attention_norm=cross_attention_norm, - ) - - def _to_3d(self, hidden_states, height, weight): - return hidden_states.permute(0, 2, 3, 1).reshape( - hidden_states.shape[0], height * weight, -1 - ) - - def _to_4d(self, hidden_states, height, weight): - return hidden_states.permute(0, 2, 1).reshape( - hidden_states.shape[0], -1, height, weight - ) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - emb=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - if self.add_self_attention: - norm_hidden_states = self.norm1(hidden_states, emb) - height, weight = norm_hidden_states.shape[2:] - norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) - attn_output = self.attn1( - norm_hidden_states, encoder_hidden_states=None, **cross_attention_kwargs - ) - attn_output = self._to_4d(attn_output, height, weight) - hidden_states = attn_output + hidden_states - norm_hidden_states = self.norm2(hidden_states, emb) - height, weight = norm_hidden_states.shape[2:] - norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - attn_output = self._to_4d(attn_output, height, weight) - hidden_states = attn_output + hidden_states - return hidden_states - - -class KCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - cross_attention_dim: "int", - dropout: "float" = 0.0, - num_layers: "int" = 4, - resnet_group_size: "int" = 32, - add_downsample=True, - attn_num_head_channels: "int" = 64, - add_self_attention: "bool" = False, - resnet_eps: "float" = 1e-05, - resnet_act_fn: "str" = "gelu", - ): - super().__init__() - resnets = [] - attentions = [] - self.has_cross_attention = True - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=resnet_eps, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - attentions.append( - KAttentionBlock( - 
out_channels, - out_channels // attn_num_head_channels, - attn_num_head_channels, - cross_attention_dim=cross_attention_dim, - temb_channels=temb_channels, - attention_bias=True, - add_self_attention=add_self_attention, - cross_attention_norm=True, - group_size=resnet_group_size, - ) - ) - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - if add_downsample: - self.downsamplers = nn.ModuleList([KDownsample2D()]) - else: - self.downsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - output_states = () - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - ) - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - if self.downsamplers is None: - output_states += (None,) - else: - output_states += (hidden_states,) - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - return hidden_states, output_states - - -class AttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - prev_output_channel: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet, attn in zip(self.resnets, self.attentions): - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple 
= res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - prev_output_channel: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - for resnet, attn in zip(self.resnets, self.attentions): - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - 
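# With return_dict=False the checkpointed transformer returns a plain tuple, so the
# trailing [0] below extracts the sample tensor.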
)[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - prev_output_channel: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = 
resnet(hidden_states, temb=None) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - return hidden_states - - -class AttnUpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - return hidden_states - - -class AttnSkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - prev_output_channel: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - upsample_padding=1, - add_upsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(resnet_in_channels + res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - 
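# resnet_up is built with up=True and kernel="fir" (kwargs below), so upsampling happens
# inside the resnet itself; skip_conv then projects out_channels down to 3 channels to
# update the image-space skip_sample in forward().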
time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d( - out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), - num_channels=out_channels, - eps=resnet_eps, - affine=True, - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None - ): - for resnet in self.resnets: - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = self.attentions[0](hidden_states) - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb) - return hidden_states, skip_sample - - -class SkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - prev_output_channel: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_pre_norm: "bool" = True, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, - upsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min((resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d( - out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), - num_channels=out_channels, - eps=resnet_eps, - affine=True, - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None - ): - for resnet in 
self.resnets: - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb) - return hidden_states, skip_sample - - -class ResnetUpsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - prev_output_channel: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - return hidden_states - - -class SimpleCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - prev_output_channel: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - 
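All of the up blocks above (UpBlock2D, SkipUpBlock2D, ResnetUpsampleBlock2D, and their attention variants) share the same skip-connection bookkeeping in `forward`: each layer pops the most recent encoder residual off `res_hidden_states_tuple` and concatenates it with the running activations along the channel axis, which is why every `ResnetBlock2D` is built with `in_channels=resnet_in_channels + res_skip_channels`. A minimal sketch of just that bookkeeping, with made-up shapes:

```python
import torch

hidden_states = torch.randn(1, 64, 8, 8)                    # running decoder activations
res_hidden_states_tuple = (
    torch.randn(1, 32, 8, 8),                                # pushed first on the way down
    torch.randn(1, 64, 8, 8),                                # pushed last, consumed first
)

for _ in range(len(res_hidden_states_tuple)):
    res_hidden_states = res_hidden_states_tuple[-1]          # newest encoder residual
    res_hidden_states_tuple = res_hidden_states_tuple[:-1]   # pop it off the stack
    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    # a ResnetBlock2D would now project the concatenated channels back to out_channels;
    # slicing stands in for that projection in this sketch
    hidden_states = hidden_states[:, :64]

print(hidden_states.shape)  # torch.Size([1, 64, 8, 8])
```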
self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - self.num_heads = out_channels // self.attn_num_head_channels - for i in range(num_layers): - res_skip_channels = in_channels if i == num_layers - 1 else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - for resnet, attn in zip(self.resnets, self.attentions): - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - return hidden_states - - -class KUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 5, - resnet_eps: "float" = 1e-05, - resnet_act_fn: "str" = "gelu", - resnet_group_size: "Optional[int]" = 32, - add_upsample=True, - ): - super().__init__() - resnets = [] - k_in_channels = 2 * out_channels - k_out_channels = in_channels - num_layers = num_layers - 1 - for i in range(num_layers): - in_channels = k_in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=k_out_channels - if i == num_layers - 1 - else out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=groups, - groups_out=groups_out, - dropout=dropout, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList([KUpsample2D()]) - else: - self.upsamplers = 
None - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - res_hidden_states_tuple = res_hidden_states_tuple[-1] - if res_hidden_states_tuple is not None: - hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - return hidden_states - - -class KCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: "int", - out_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 4, - resnet_eps: "float" = 1e-05, - resnet_act_fn: "str" = "gelu", - resnet_group_size: "int" = 32, - attn_num_head_channels=1, - cross_attention_dim: "int" = 768, - add_upsample: "bool" = True, - upcast_attention: "bool" = False, - ): - super().__init__() - resnets = [] - attentions = [] - is_first_block = in_channels == out_channels == temb_channels - is_middle_block = in_channels != out_channels - add_self_attention = True if is_first_block else False - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - k_in_channels = out_channels if is_first_block else 2 * out_channels - k_out_channels = in_channels - num_layers = num_layers - 1 - for i in range(num_layers): - in_channels = k_in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - if is_middle_block and i == num_layers - 1: - conv_2d_out_channels = k_out_channels - else: - conv_2d_out_channels = None - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - conv_2d_out_channels=conv_2d_out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=groups, - groups_out=groups_out, - dropout=dropout, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - attentions.append( - KAttentionBlock( - k_out_channels if i == num_layers - 1 else out_channels, - k_out_channels // attn_num_head_channels - if i == num_layers - 1 - else out_channels // attn_num_head_channels, - attn_num_head_channels, - cross_attention_dim=cross_attention_dim, - temb_channels=temb_channels, - attention_bias=True, - add_self_attention=add_self_attention, - cross_attention_norm=True, - upcast_attention=upcast_attention, - ) - ) - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - if add_upsample: - self.upsamplers = nn.ModuleList([KUpsample2D()]) - else: - self.upsamplers = None - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - res_hidden_states_tuple = res_hidden_states_tuple[-1] - if res_hidden_states_tuple is not None: - hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): 
- def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - return hidden_states - - -AttnProcessor = Union[ - CrossAttnProcessor, - XFormersCrossAttnProcessor, - SlicedAttnProcessor, - CrossAttnAddedKVProcessor, - SlicedAttnAddedKVProcessor, - LoRACrossAttnProcessor, - LoRAXFormersCrossAttnProcessor, -] - -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" - -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" - - -class UNet2DConditionLoadersMixin: - def load_attn_procs( - self, - pretrained_model_name_or_path_or_dict: "Union[str, Dict[str, torch.Tensor]]", - **kwargs, - ): - """ - Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be - defined in - [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) - and be a `torch.nn.Module` class. - - - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. - - - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - if ( - is_safetensors_available() - and weight_name is None - or weight_name is not None - and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except EnvironmentError: - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - attn_processors = {} - is_lora = all("lora" in k for k in state_dict.keys()) - if is_lora: - lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join( - key.split(".")[-3:] - ) - lora_grouped_dict[attn_processor_key][sub_key] = value - for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - attn_processors[key] = LoRACrossAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=rank, - ) - 
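The grouping above relies on the fixed layout of LoRA keys: the last three dot-separated components name a LoRA weight inside a processor, everything before them names the attention processor it belongs to, and the rank and dimensions are read directly off the weight shapes. A small sketch of that split on an illustrative key:

```python
# The example key only illustrates the "<processor path>.<lora weight path>" layout;
# it is not taken from a real checkpoint.
key = "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_k_lora.down.weight"

attn_processor_key = ".".join(key.split(".")[:-3])
sub_key = ".".join(key.split(".")[-3:])

print(attn_processor_key)  # down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
print(sub_key)             # to_k_lora.down.weight

# The per-processor dimensions are then recovered from the grouped weights:
#   rank                = to_k_lora.down.weight.shape[0]
#   cross_attention_dim = to_k_lora.down.weight.shape[1]
#   hidden_size         = to_k_lora.up.weight.shape[0]
```

In practice the grouped weights come either from a directory holding `pytorch_lora_weights.bin` / `pytorch_lora_weights.safetensors` or from an in-memory state dict, e.g. `unet.load_attn_procs("./lora_weights")` (the path here is hypothetical).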
attn_processors[key].load_state_dict(value_dict) - else: - raise ValueError( - f"{model_file} does not seem to be in the correct format expected by LoRA training." - ) - attn_processors = {k: v for k, v in attn_processors.items()} - self.set_attn_processor(attn_processors) - - def save_attn_procs( - self, - save_directory: "Union[str, os.PathLike]", - is_main_process: "bool" = True, - weight_name: "str" = None, - save_function: "Callable" = None, - safe_serialization: "bool" = False, - **kwargs, - ): - """ - Save an attention processor to a directory, so that it can be re-loaded using the - `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - """ - weight_name = weight_name or deprecate( - "weights_name", - "0.18.0", - "`weights_name` is deprecated, please use `weight_name` instead.", - take_from=kwargs, - ) - if os.path.isfile(save_directory): - logger.error( - f"Provided path ({save_directory}) should be a directory, not a file" - ) - return - if save_function is None: - if safe_serialization: - - def save_function(weights, filename): - return safetensors.torch.save_file( - weights, filename, metadata={"format": "pt"} - ) - - else: - save_function = torch.save - os.makedirs(save_directory, exist_ok=True) - model_to_save = AttnProcsLayers(self.attn_processors) - state_dict = model_to_save.state_dict() - if weight_name is None: - if safe_serialization: - weight_name = LORA_WEIGHT_NAME_SAFE - else: - weight_name = LORA_WEIGHT_NAME - save_function(state_dict, os.path.join(save_directory, weight_name)) - logger.info( - f"Model weights saved in {os.path.join(save_directory, weight_name)}" - ) - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: "torch.FloatTensor" - - -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): - """ - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the - mid block layer if `None`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, or `"projection"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - timestep_post_act (`str, *optional*, default to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
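When `num_class_embeds` is set with `class_embed_type=None`, the model builds a plain `nn.Embedding` and sums the looked-up class embedding onto the timestep embedding inside `forward`, so `class_labels` becomes a required input. A minimal sketch of that mechanism, with illustrative sizes:

```python
import torch
import torch.nn as nn

num_class_embeds, time_embed_dim = 10, 128                  # illustrative sizes only
class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

emb = torch.randn(4, time_embed_dim)                        # timestep embedding for a batch of 4
class_labels = torch.randint(0, num_class_embeds, (4,))
emb = emb + class_embedding(class_labels)                   # class embedding is summed onto it
```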
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: "Optional[int]" = None, - in_channels: "int" = 4, - out_channels: "int" = 4, - center_input_sample: "bool" = False, - flip_sin_to_cos: "bool" = True, - freq_shift: "int" = 0, - down_block_types: "Tuple[str]" = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: "Optional[str]" = "UNetMidBlock2DCrossAttn", - up_block_types: "Tuple[str]" = ( - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - ), - only_cross_attention: "Union[bool, Tuple[bool]]" = False, - block_out_channels: "Tuple[int]" = (320, 640, 1280, 1280), - layers_per_block: "int" = 2, - downsample_padding: "int" = 1, - mid_block_scale_factor: "float" = 1, - act_fn: "str" = "silu", - norm_num_groups: "Optional[int]" = 32, - norm_eps: "float" = 1e-05, - cross_attention_dim: "int" = 1280, - attention_head_dim: "Union[int, Tuple[int]]" = 8, - dual_cross_attention: "bool" = False, - use_linear_projection: "bool" = False, - class_embed_type: "Optional[str]" = None, - num_class_embeds: "Optional[int]" = None, - upcast_attention: "bool" = False, - resnet_time_scale_shift: "str" = "default", - time_embedding_type: "str" = "positional", - timestep_post_act: "Optional[str]" = None, - time_cond_proj_dim: "Optional[int]" = None, - conv_in_kernel: "int" = 3, - conv_out_kernel: "int" = 3, - projection_class_embeddings_input_dim: "Optional[int]" = None, - ): - super().__init__() - self.sample_size = sample_size - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - if not isinstance(only_cross_attention, bool) and len( - only_cross_attention - ) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( - down_block_types - ): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[0], - kernel_size=conv_in_kernel, - padding=conv_in_padding, - ) - if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError( - f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." - ) - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, - set_W_to_weight=False, - log=False, - flip_sin_to_cos=flip_sin_to_cos, - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos, freq_shift - ) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. 
Pleaes make sure to use one of `fourier` or `positional`." - ) - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = TimestepEmbedding( - projection_class_embeddings_input_dim, time_embed_dim - ) - else: - self.class_embedding = None - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - self.num_upsamplers = 0 - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - output_channel 
= reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], - num_groups=norm_num_groups, - eps=norm_eps, - ) - self.conv_act = nn.SiLU() - else: - self.conv_norm_out = None - self.conv_act = None - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], - out_channels, - kernel_size=conv_out_kernel, - padding=conv_out_padding, - ) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - """ - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - processors = {} - - def fn_recursive_add_processors( - name: "str", - module: "torch.nn.Module", - processors: "Dict[str, AttnProcessor]", - ): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - return processors - - def set_attn_processor( - self, processor: "Union[AttnProcessor, Dict[str, AttnProcessor]]" - ): - """ - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor( - name: "str", module: "torch.nn.Module", processor - ): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - """ - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: "torch.nn.Module"): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - num_slicable_layers = len(sliceable_head_dims) - if slice_size == "auto": - slice_size = [(dim // 2) for dim in sliceable_head_dims] - elif slice_size == "max": - slice_size = num_slicable_layers * [1] - slice_size = ( - num_slicable_layers * [slice_size] - if not isinstance(slice_size, list) - else slice_size - ) - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
- ) - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - def fn_recursive_set_attention_slice( - module: "torch.nn.Module", slice_size: "List[int]" - ): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance( - module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D) - ): - module.gradient_checkpointing = value - - def forward( - self, - sample: "torch.FloatTensor", - timestep: "Union[torch.Tensor, float, int]", - encoder_hidden_states: "torch.Tensor", - class_labels: "Optional[torch.Tensor]" = None, - timestep_cond: "Optional[torch.Tensor]" = None, - attention_mask: "Optional[torch.Tensor]" = None, - cross_attention_kwargs: "Optional[Dict[str, Any]]" = None, - down_block_additional_residuals: "Optional[Tuple[torch.Tensor]]" = None, - mid_block_additional_residual: "Optional[torch.Tensor]" = None, - return_dict: "bool" = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - """ - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
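Both return conventions described above carry the same predicted sample, either wrapped in a `UNet2DConditionOutput` or as the first element of a plain tuple. A minimal call sketch, assuming a configured `unet` and inputs with the documented shapes already exist:

```python
# `unet`, `noise`, `timestep` and `encoder_hidden_states` are assumed to exist
# with the shapes described in the docstring above.
out = unet(noise, timestep, encoder_hidden_states)           # UNet2DConditionOutput
sample = out.sample

out_tuple = unet(noise, timestep, encoder_hidden_states, return_dict=False)
sample = out_tuple[0]                                        # first element is the sample tensor
```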
- """ - default_overall_up_factor = 2 ** self.num_upsamplers - forward_upsample_size = False - upsample_size = None - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - if attention_mask is not None: - attention_mask = (1 - attention_mask) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - timesteps = timestep - if not torch.is_tensor(timesteps): - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None] - timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.time_proj(timesteps) - t_emb = t_emb - emb = self.time_embedding(t_emb, timestep_cond) - if self.class_embedding is not None: - if class_labels is None: - raise ValueError( - "class_labels should be provided when num_class_embeds > 0" - ) - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels) - emb = emb + class_emb - sample = self.conv_in(sample) - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if ( - hasattr(downsample_block, "has_cross_attention") - and downsample_block.has_cross_attention - ): - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - down_block_res_samples += res_samples - if down_block_additional_residuals is not None: - new_down_block_res_samples = () - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = ( - down_block_res_sample + down_block_additional_residual - ) - new_down_block_res_samples += (down_block_res_sample,) - down_block_res_samples = new_down_block_res_samples - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - if mid_block_additional_residual is not None: - sample = sample + mid_block_additional_residual - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - if ( - hasattr(upsample_block, "has_cross_attention") - and upsample_block.has_cross_attention - ): - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - ) - if self.conv_norm_out: - sample = 
self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - if not return_dict: - return (sample,) - return UNet2DConditionOutput(sample=sample) - - -def rename_key(key): - regex = "\\w+[.]\\d+" - pats = re.findall(regex, key) - for pat in pats: - key = key.replace(pat, "_".join(pat.split("."))) - return key - - -class Encoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - down_block_types=("DownEncoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - double_z=True, - ): - super().__init__() - self.layers_per_block = layers_per_block - self.conv_in = torch.nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1 - ) - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-06, - downsample_padding=0, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=None, - temb_channels=None, - ) - self.down_blocks.append(down_block) - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-06, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=norm_num_groups, - temb_channels=None, - ) - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-06 - ) - self.conv_act = nn.SiLU() - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv2d( - block_out_channels[-1], conv_out_channels, 3, padding=1 - ) - - def forward(self, x): - sample = x - sample = self.conv_in(sample) - for down_block in self.down_blocks: - sample = down_block(sample) - sample = self.mid_block(sample) - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - return sample - - -class Decoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - ): - super().__init__() - self.layers_per_block = layers_per_block - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 - ) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-06, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=norm_num_groups, - temb_channels=None, - ) - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=None, - add_upsample=not is_final_block, - 
resnet_eps=1e-06, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=None, - temb_channels=None, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-06 - ) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - def forward(self, z): - sample = z - sample = self.conv_in(sample) - sample = self.mid_block(sample) - for up_block in self.up_blocks: - sample = up_block(sample) - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - return sample - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - def __init__( - self, - n_e, - vq_embed_dim, - beta, - remap=None, - unknown_index="random", - sane_index_shape=False, - legacy=True, - ): - super().__init__() - self.n_e = n_e - self.vq_embed_dim = vq_embed_dim - self.beta = beta - self.legacy = legacy - self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - None - else: - self.re_embed = n_e - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used - if self.re_embed > self.used.shape[0]: - inds[inds >= self.used.shape[0]] = 0 - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z): - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.vq_embed_dim) - min_encoding_indices = torch.argmin( - torch.cdist(z_flattened, self.embedding.weight), dim=1 - ) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( - (z_q - z.detach()) ** 2 - ) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( - (z_q - z.detach()) ** 2 - ) - z_q = z + (z_q - z).detach() - z_q = z_q.permute(0, 3, 1, 2).contiguous() - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3] - ) - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, 
shape): - if self.remap is not None: - indices = indices.reshape(shape[0], -1) - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) - z_q = self.embedding(indices) - if shape is not None: - z_q = z_q.view(shape) - z_q = z_q.permute(0, 3, 1, 2).contiguous() - return z_q - - -@dataclass -class DecoderOutput(BaseOutput): - """ - Output of decoding method. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Decoded output sample of the model. Output of the last layer of the model. - """ - - sample: "torch.FloatTensor" - - -@dataclass -class VQEncoderOutput(BaseOutput): - """ - Output of VQModel encoding method. - - Args: - latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Encoded output sample of the model. Output of the last layer of the model. - """ - - latents: "torch.FloatTensor" - - -class VQModel(ModelMixin, ConfigMixin): - """VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray - Kavukcuoglu. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. - vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. - scaling_factor (`float`, *optional*, defaults to `0.18215`): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
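The `scaling_factor` convention above is applied by the caller (typically a diffusion pipeline) rather than inside the model. A minimal sketch of the arithmetic, with an illustrative latent shape:

```python
import torch

scaling_factor = 0.18215
z = torch.randn(1, 3, 32, 32)              # encoder latents (illustrative shape)

z_scaled = z * scaling_factor              # what the diffusion model is trained on
z_restored = z_scaled / scaling_factor     # undone again before decoding
```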
- """ - - @register_to_config - def __init__( - self, - in_channels: "int" = 3, - out_channels: "int" = 3, - down_block_types: "Tuple[str]" = ("DownEncoderBlock2D",), - up_block_types: "Tuple[str]" = ("UpDecoderBlock2D",), - block_out_channels: "Tuple[int]" = (64,), - layers_per_block: "int" = 1, - act_fn: "str" = "silu", - latent_channels: "int" = 3, - sample_size: "int" = 32, - num_vq_embeddings: "int" = 256, - norm_num_groups: "int" = 32, - vq_embed_dim: "Optional[int]" = None, - scaling_factor: "float" = 0.18215, - ): - super().__init__() - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=False, - ) - vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels - self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) - self.quantize = VectorQuantizer( - num_vq_embeddings, - vq_embed_dim, - beta=0.25, - remap=None, - sane_index_shape=False, - ) - self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - ) - - def encode( - self, x: "torch.FloatTensor", return_dict: "bool" = True - ) -> VQEncoderOutput: - h = self.encoder(x) - h = self.quant_conv(h) - if not return_dict: - return (h,) - return VQEncoderOutput(latents=h) - - def decode( - self, - h: "torch.FloatTensor", - force_not_quantize: "bool" = False, - return_dict: "bool" = True, - ) -> Union[DecoderOutput, torch.FloatTensor]: - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - if not return_dict: - return (dec,) - return DecoderOutput(sample=dec) - - def forward( - self, sample: "torch.FloatTensor", return_dict: "bool" = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - """ - Args: - sample (`torch.FloatTensor`): Input sample. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
- """ - x = sample - h = self.encode(x).latents - dec = self.decode(h).sample - if not return_dict: - return (dec,) - return DecoderOutput(sample=dec) - - -class LDMBertAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: "int", - num_heads: "int", - head_dim: "int", - dropout: "float" = 0.0, - is_decoder: "bool" = False, - bias: "bool" = False, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = head_dim - self.inner_dim = head_dim * num_heads - self.scaling = self.head_dim ** -0.5 - self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.out_proj = nn.Linear(self.inner_dim, embed_dim) - - def _shape(self, tensor: "torch.Tensor", seq_len: "int", bsz: "int"): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: "torch.Tensor", - key_value_states: "Optional[torch.Tensor]" = None, - past_key_value: "Optional[Tuple[torch.Tensor]]" = None, - attention_mask: "Optional[torch.Tensor]" = None, - layer_head_mask: "Optional[torch.Tensor]" = None, - output_attentions: "bool" = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) * self.scaling - if is_cross_attention and past_key_value is not None: - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if self.is_decoder: - past_key_value = key_states, value_states - proj_shape = bsz * self.num_heads, -1, self.head_dim - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {bsz * self.num_heads, tgt_len, src_len}, but is {attn_weights.size()}" - ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {bsz, 1, tgt_len, src_len}, but is {attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise 
ValueError( - f"Head mask for a single layer should be of size {self.num_heads,}, but is {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - if output_attentions: - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) - else: - attn_weights_reshaped = None - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - attn_output = torch.bmm(attn_probs, value_states) - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {bsz, self.num_heads, tgt_len, self.head_dim}, but is {attn_output.size()}" - ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) - attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -class ClassInstantier(OrderedDict): - def __getitem__(self, key): - content = super().__getitem__(key) - cls, kwargs = content if isinstance(content, tuple) else (content, {}) - return cls(**kwargs) - - -# ACT2CLS = { -# "gelu": GELUActivation, -# "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), -# "gelu_fast": FastGELUActivation, -# "gelu_new": NewGELUActivation, -# "gelu_python": (GELUActivation, {"use_gelu_python": True}), -# "gelu_pytorch_tanh": PytorchGELUTanh, -# "linear": LinearActivation, -# "mish": MishActivation, -# "quick_gelu": QuickGELUActivation, -# "relu": nn.ReLU, -# "relu6": nn.ReLU6, -# "sigmoid": nn.Sigmoid, -# "silu": SiLUActivation, -# "swish": SiLUActivation, -# "tanh": nn.Tanh, -# } -# ACT2FN = ClassInstantier(ACT2CLS) - - -class LDMBertEncoderLayer(nn.Module): - def __init__(self, config: "LDMBertConfig"): - super().__init__() - self.embed_dim = config.d_model - self.self_attn = LDMBertAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - head_dim=config.head_dim, - dropout=config.attention_dropout, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: "torch.FloatTensor", - attention_mask: "torch.FloatTensor", - layer_head_mask: "torch.FloatTensor", - output_attentions: "Optional[bool]" = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
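The reshapes in `LDMBertAttention.forward` above are easiest to follow as a shape walk-through. A small self-contained sketch with made-up sizes (none of these numbers come from the patch):

```python
import torch

bsz, seq, heads, head_dim = 2, 7, 8, 40
q = torch.randn(bsz * heads, seq, head_dim) * head_dim ** -0.5   # query pre-scaled, as above
k = torch.randn(bsz * heads, seq, head_dim)
v = torch.randn(bsz * heads, seq, head_dim)
w = torch.bmm(q, k.transpose(1, 2)).softmax(dim=-1)              # (bsz*heads, seq, seq)
o = torch.bmm(w, v)                                              # (bsz*heads, seq, head_dim)
o = o.view(bsz, heads, seq, head_dim).transpose(1, 2).reshape(bsz, seq, heads * head_dim)
```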
- """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout( - hidden_states, p=self.activation_dropout, training=self.training - ) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - outputs = (hidden_states,) - if output_attentions: - outputs += (attn_weights,) - return outputs - - -class PaintByExampleMapper(nn.Module): - def __init__(self, config): - super().__init__() - num_layers = (config.num_hidden_layers + 1) // 5 - hid_size = config.hidden_size - num_heads = 1 - self.blocks = nn.ModuleList( - [ - BasicTransformerBlock( - hid_size, - num_heads, - hid_size, - activation_fn="gelu", - attention_bias=True, - ) - for _ in range(num_layers) - ] - ) - - def forward(self, hidden_states): - for block in self.blocks: - hidden_states = block(hidden_states) - return hidden_states - - -class GaussianSmoothing(torch.nn.Module): - """ - Arguments: - Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input - using a depthwise convolution. - channels (int, sequence): Number of channels of the input tensors. Output will - have this number of channels as well. - kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the - gaussian kernel. dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). - """ - - def __init__( - self, - channels: "int" = 1, - kernel_size: "int" = 3, - sigma: "float" = 0.5, - dim: "int" = 2, - ): - super().__init__() - if isinstance(kernel_size, int): - kernel_size = [kernel_size] * dim - if isinstance(sigma, float): - sigma = [sigma] * dim - kernel = 1 - meshgrids = torch.meshgrid( - [torch.arange(size, dtype=torch.float32) for size in kernel_size] - ) - for size, std, mgrid in zip(kernel_size, sigma, meshgrids): - mean = (size - 1) / 2 - kernel *= ( - 1 - / (std * math.sqrt(2 * math.pi)) - * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) - ) - kernel = kernel / torch.sum(kernel) - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(channels, *([1] * (kernel.dim() - 1))) - self.register_buffer("weight", kernel) - self.groups = channels - if dim == 1: - self.conv = F.conv1d - elif dim == 2: - self.conv = F.conv2d - elif dim == 3: - self.conv = F.conv3d - else: - raise RuntimeError( - "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) - ) - - def forward(self, input): - """ - Arguments: - Apply gaussian filter to input. - input (torch.Tensor): Input to apply gaussian filter on. - Returns: - filtered (torch.Tensor): Filtered output. 
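`GaussianSmoothing` builds a fixed depthwise kernel and applies it with `groups=channels`; because the convolution runs without padding, a same-size output needs explicit padding at the call site. A usage sketch with illustrative values:

```python
import torch
import torch.nn.functional as F

smoother = GaussianSmoothing(channels=3, kernel_size=3, sigma=0.5, dim=2)
img = torch.randn(1, 3, 64, 64)
blurred = smoother(F.pad(img, (1, 1, 1, 1), mode="reflect"))   # back to (1, 3, 64, 64)
```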
- """ - return self.conv(input, weight=self.weight, groups=self.groups) - - -class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): - """ - This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. - - It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image - embeddings. - """ - - @register_to_config - def __init__(self, embedding_dim: "int" = 768): - super().__init__() - self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.std = nn.Parameter(torch.ones(1, embedding_dim)) - - def scale(self, embeds): - embeds = (embeds - self.mean) * 1.0 / self.std - return embeds - - def unscale(self, embeds): - embeds = embeds * self.std + self.mean - return embeds - - -class UnCLIPTextProjModel(ModelMixin, ConfigMixin): - """ - Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the - decoder. - - For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 - """ - - @register_to_config - def __init__( - self, - *, - clip_extra_context_tokens: int = 4, - clip_embeddings_dim: int = 768, - time_embed_dim: int, - cross_attention_dim, - ): - super().__init__() - self.learned_classifier_free_guidance_embeddings = nn.Parameter( - torch.zeros(clip_embeddings_dim) - ) - self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) - self.clip_image_embeddings_project_to_time_embeddings = nn.Linear( - clip_embeddings_dim, time_embed_dim - ) - self.clip_extra_context_tokens = clip_extra_context_tokens - self.clip_extra_context_tokens_proj = nn.Linear( - clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim - ) - self.encoder_hidden_states_proj = nn.Linear( - clip_embeddings_dim, cross_attention_dim - ) - self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) - - def forward( - self, - *, - image_embeddings, - prompt_embeds, - text_encoder_hidden_states, - do_classifier_free_guidance, - ): - if do_classifier_free_guidance: - image_embeddings_batch_size = image_embeddings.shape[0] - classifier_free_guidance_embeddings = ( - self.learned_classifier_free_guidance_embeddings.unsqueeze(0) - ) - classifier_free_guidance_embeddings = ( - classifier_free_guidance_embeddings.expand( - image_embeddings_batch_size, -1 - ) - ) - image_embeddings = torch.cat( - [classifier_free_guidance_embeddings, image_embeddings], dim=0 - ) - assert image_embeddings.shape[0] == prompt_embeds.shape[0] - batch_size = prompt_embeds.shape[0] - time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) - time_projected_image_embeddings = ( - self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) - ) - additive_clip_time_embeddings = ( - time_projected_image_embeddings + time_projected_prompt_embeds - ) - clip_extra_context_tokens = self.clip_extra_context_tokens_proj( - image_embeddings - ) - clip_extra_context_tokens = clip_extra_context_tokens.reshape( - batch_size, -1, self.clip_extra_context_tokens - ) - text_encoder_hidden_states = self.encoder_hidden_states_proj( - text_encoder_hidden_states - ) - text_encoder_hidden_states = self.text_encoder_hidden_states_norm( - text_encoder_hidden_states - ) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) - text_encoder_hidden_states = torch.cat( - [clip_extra_context_tokens, text_encoder_hidden_states], dim=2 - ) - return text_encoder_hidden_states, additive_clip_time_embeddings - - -class 
UNetMidBlockFlatCrossAttn(nn.Module): - def __init__( - self, - in_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - resnets = [ - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = resnet(hidden_states, temb) - return hidden_states - - -class UNetMidBlockFlatSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: "int", - temb_channels: "int", - dropout: "float" = 0.0, - num_layers: "int" = 1, - resnet_eps: "float" = 1e-06, - resnet_time_scale_shift: "str" = "default", - resnet_act_fn: "str" = "swish", - resnet_groups: "int" = 32, - resnet_pre_norm: "bool" = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - ): - super().__init__() - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.num_heads = in_channels // self.attn_num_head_channels - resnets = [ - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - 
non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - for _ in range(num_layers): - attentions.append( - CrossAttention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - hidden_states = resnet(hidden_states, temb) - return hidden_states - - -class UNetFlatConditionModel(ModelMixin, ConfigMixin): - """ - UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a - timestep and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): - The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip - the mid block layer if `None`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. 
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, or `"projection"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - timestep_post_act (`str, *optional*, default to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
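The `__init__` below enforces that `down_block_types`, `up_block_types`, and `block_out_channels` all have the same length. A config sketch that satisfies those invariants by echoing the documented defaults; this is hypothetical usage, not part of the patch:

```python
flat_unet = UNetFlatConditionModel(
    down_block_types=("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat"),
    up_block_types=("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat"),
    block_out_channels=(320, 640, 1280, 1280),   # one entry per down block
    layers_per_block=2,
    cross_attention_dim=1280,
    attention_head_dim=8,
)
```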
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: "Optional[int]" = None, - in_channels: "int" = 4, - out_channels: "int" = 4, - center_input_sample: "bool" = False, - flip_sin_to_cos: "bool" = True, - freq_shift: "int" = 0, - down_block_types: "Tuple[str]" = ( - "CrossAttnDownBlockFlat", - "CrossAttnDownBlockFlat", - "CrossAttnDownBlockFlat", - "DownBlockFlat", - ), - mid_block_type: "Optional[str]" = "UNetMidBlockFlatCrossAttn", - up_block_types: "Tuple[str]" = ( - "UpBlockFlat", - "CrossAttnUpBlockFlat", - "CrossAttnUpBlockFlat", - "CrossAttnUpBlockFlat", - ), - only_cross_attention: "Union[bool, Tuple[bool]]" = False, - block_out_channels: "Tuple[int]" = (320, 640, 1280, 1280), - layers_per_block: "int" = 2, - downsample_padding: "int" = 1, - mid_block_scale_factor: "float" = 1, - act_fn: "str" = "silu", - norm_num_groups: "Optional[int]" = 32, - norm_eps: "float" = 1e-05, - cross_attention_dim: "int" = 1280, - attention_head_dim: "Union[int, Tuple[int]]" = 8, - dual_cross_attention: "bool" = False, - use_linear_projection: "bool" = False, - class_embed_type: "Optional[str]" = None, - num_class_embeds: "Optional[int]" = None, - upcast_attention: "bool" = False, - resnet_time_scale_shift: "str" = "default", - time_embedding_type: "str" = "positional", - timestep_post_act: "Optional[str]" = None, - time_cond_proj_dim: "Optional[int]" = None, - conv_in_kernel: "int" = 3, - conv_out_kernel: "int" = 3, - projection_class_embeddings_input_dim: "Optional[int]" = None, - ): - super().__init__() - self.sample_size = sample_size - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - if not isinstance(only_cross_attention, bool) and len( - only_cross_attention - ) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( - down_block_types - ): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = LinearMultiDim( - in_channels, - block_out_channels[0], - kernel_size=conv_in_kernel, - padding=conv_in_padding, - ) - if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError( - f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." - ) - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, - set_W_to_weight=False, - log=False, - flip_sin_to_cos=flip_sin_to_cos, - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos, freq_shift - ) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. 
Pleaes make sure to use one of `fourier` or `positional`." - ) - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = TimestepEmbedding( - projection_class_embeddings_input_dim, time_embed_dim - ) - else: - self.class_embedding = None - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - if mid_block_type == "UNetMidBlockFlatCrossAttn": - self.mid_block = UNetMidBlockFlatCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": - self.mid_block = UNetMidBlockFlatSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - self.num_upsamplers = 0 - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - 
output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], - num_groups=norm_num_groups, - eps=norm_eps, - ) - self.conv_act = nn.SiLU() - else: - self.conv_norm_out = None - self.conv_act = None - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = LinearMultiDim( - block_out_channels[0], - out_channels, - kernel_size=conv_out_kernel, - padding=conv_out_padding, - ) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - """ - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - processors = {} - - def fn_recursive_add_processors( - name: "str", - module: "torch.nn.Module", - processors: "Dict[str, AttnProcessor]", - ): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - return processors - - def set_attn_processor( - self, processor: "Union[AttnProcessor, Dict[str, AttnProcessor]]" - ): - """ - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor( - name: "str", module: "torch.nn.Module", processor - ): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - """ - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: "torch.nn.Module"): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - num_slicable_layers = len(sliceable_head_dims) - if slice_size == "auto": - slice_size = [(dim // 2) for dim in sliceable_head_dims] - elif slice_size == "max": - slice_size = num_slicable_layers * [1] - slice_size = ( - num_slicable_layers * [slice_size] - if not isinstance(slice_size, list) - else slice_size - ) - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
- ) - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - def fn_recursive_set_attention_slice( - module: "torch.nn.Module", slice_size: "List[int]" - ): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance( - module, - (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat), - ): - module.gradient_checkpointing = value - - def forward( - self, - sample: "torch.FloatTensor", - timestep: "Union[torch.Tensor, float, int]", - encoder_hidden_states: "torch.Tensor", - class_labels: "Optional[torch.Tensor]" = None, - timestep_cond: "Optional[torch.Tensor]" = None, - attention_mask: "Optional[torch.Tensor]" = None, - cross_attention_kwargs: "Optional[Dict[str, Any]]" = None, - down_block_additional_residuals: "Optional[Tuple[torch.Tensor]]" = None, - mid_block_additional_residual: "Optional[torch.Tensor]" = None, - return_dict: "bool" = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - """ - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
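Per the docstring above and the `return_dict` handling at the end of `forward`, both call styles look like this; `sample`, `timestep`, and `encoder_hidden_states` are placeholders with the shapes listed in the Args, and `flat_unet` is any instance such as the earlier sketch:

```python
out = flat_unet(sample, timestep, encoder_hidden_states)        # -> UNet2DConditionOutput
noise_pred = out.sample
(noise_pred,) = flat_unet(sample, timestep, encoder_hidden_states, return_dict=False)   # -> plain tuple
```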
- """ - default_overall_up_factor = 2 ** self.num_upsamplers - forward_upsample_size = False - upsample_size = None - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - if attention_mask is not None: - attention_mask = (1 - attention_mask) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - timesteps = timestep - if not torch.is_tensor(timesteps): - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None] - timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.time_proj(timesteps) - t_emb = t_emb - emb = self.time_embedding(t_emb, timestep_cond) - if self.class_embedding is not None: - if class_labels is None: - raise ValueError( - "class_labels should be provided when num_class_embeds > 0" - ) - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels) - emb = emb + class_emb - sample = self.conv_in(sample) - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if ( - hasattr(downsample_block, "has_cross_attention") - and downsample_block.has_cross_attention - ): - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - down_block_res_samples += res_samples - if down_block_additional_residuals is not None: - new_down_block_res_samples = () - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = ( - down_block_res_sample + down_block_additional_residual - ) - new_down_block_res_samples += (down_block_res_sample,) - down_block_res_samples = new_down_block_res_samples - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - if mid_block_additional_residual is not None: - sample = sample + mid_block_additional_residual - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - if ( - hasattr(upsample_block, "has_cross_attention") - and upsample_block.has_cross_attention - ): - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - ) - if self.conv_norm_out: - sample = 
self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - if not return_dict: - return (sample,) - return UNet2DConditionOutput(sample=sample) - - -class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): - """ - Utility class for storing learned text embeddings for classifier free sampling - """ - - @register_to_config - def __init__( - self, - learnable: "bool", - hidden_size: "Optional[int]" = None, - length: "Optional[int]" = None, - ): - super().__init__() - self.learnable = learnable - if self.learnable: - assert ( - hidden_size is not None - ), "learnable=True requires `hidden_size` to be set" - assert length is not None, "learnable=True requires `length` to be set" - embeddings = torch.zeros(length, hidden_size) - else: - embeddings = None - self.embeddings = torch.nn.Parameter(embeddings) diff --git a/pi/models/unet/hand_imports.py b/pi/models/unet/hand_imports.py new file mode 100644 index 0000000..652d090 --- /dev/null +++ b/pi/models/unet/hand_imports.py @@ -0,0 +1,21 @@ +from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, AttnDownBlock2D, SimpleCrossAttnDownBlock2D, \ + SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, \ + KCrossAttnDownBlock2D, ResnetUpsampleBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, \ + AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D +from diffusers.models import DualTransformer2DModel +from diffusers.models.attention import AdaGroupNorm +from diffusers.models.cross_attention import CrossAttnAddedKVProcessor, AttnProcessor +from diffusers.models.embeddings import GaussianFourierProjection, PatchEmbed, ImagePositionalEmbeddings +from diffusers.models.resnet import FirUpsample2D, FirDownsample2D, downsample_2d, upsample_2d +from diffusers.models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn +from diffusers.models.unet_2d_condition import UNet2DConditionOutput +from diffusers.models.resnet import upfirdn2d_native +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention import AttentionBlock, AdaLayerNorm, GELU, ApproximateGELU +from diffusers.models.cross_attention import LoRACrossAttnProcessor, CrossAttnProcessor, LoRAXFormersCrossAttnProcessor, \ + XFormersCrossAttnProcessor, SlicedAttnAddedKVProcessor, SlicedAttnProcessor, AttnProcessor +from diffusers.models.resnet import KDownsample2D, KUpsample2D +from diffusers.models.unet_2d_blocks import KAttentionBlock +from diffusers.models.cross_attention import LoRALinearLayer, AttnProcessor +from diffusers.models.attention import AdaLayerNormZero +from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings diff --git a/pi/models/unet/linearized.py b/pi/models/unet/linearized.py new file mode 100644 index 0000000..d421b35 --- /dev/null +++ b/pi/models/unet/linearized.py @@ -0,0 +1,5014 @@ +from __future__ import annotations +import functools +import importlib +import inspect +import json +import logging +import math +import os +from collections import defaultdict +from dataclasses import dataclass, fields +from functools import partial +from pathlib import PosixPath +from typing import Optional, Callable, Tuple, Union, Dict, List, Any, OrderedDict +import sys + +import numpy as np +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from pi.models.unet.prologue import CONFIG_NAME, LORA_WEIGHT_NAME + +logger = logging.getLogger(__name__) + + +def 
register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default + for i, (name, p) in enumerate(signature.parameters.items()) + if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all( + getattr(self, field.name) is None for field in class_fields[1:] + ) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." 
+ ) + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) in self.items()} + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: CrossAttention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, inner_dim = hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + class_labels=None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. 
Self-Attention + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class ConfigMixin: + r""" + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + config_name = None + ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError( + f"Make sure that {self.__class__} has defined a class name `config_name`" + ) + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def save_config( + self, + save_directory: Union[str, os.PathLike], + push_to_hub: bool = False, + **kwargs, + ): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~ConfigMixin.from_config`] class method. 
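To make the `ConfigMixin` contract above concrete, a hypothetical minimal subclass; it assumes the usual `config` property and JSON helpers from the surrounding file are present, and none of these names are part of the patch:

```python
class MyTinyModel(ConfigMixin):
    config_name = "my_tiny_model.json"              # file name used by save_config

    @register_to_config
    def __init__(self, hidden_size: int = 64, num_layers: int = 2):
        super().__init__()
        self.hidden_size = hidden_size

m = MyTinyModel(hidden_size=128)                    # 128 plus the default num_layers are recorded
m.save_config("./my_tiny_model")                    # writes ./my_tiny_model/my_tiny_model.json
restored = MyTinyModel.from_config(dict(m.config))  # rebuild an equivalent instance from the config
```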
+ + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + if os.path.isfile(save_directory): + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + @classmethod + def from_config( + cls, + config: Union[FrozenDict, Dict[str, Any]] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Python class from a config dictionary + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class will be instantiated. Make sure to only load + configuration files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the Python class. + `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually + overwrite same named arguments of `config`. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError( + "Please make sure to provide a config as the first positional argument." + ) + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." 
+ ) + deprecate( + "config-passed-as-path", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + config, kwargs = cls.load_config( + pretrained_model_name_or_path=config, + return_unused_kwargs=True, + **kwargs, + ) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Instantiate a Python class from a config dictionary + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "config"} + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + # Load from a PyTorch checkpoint + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" + " login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." 
+ ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError( + f"It looks like the config file at '{config_file}' is not a valid JSON file." + ) + + if return_unused_kwargs: + return config_dict, kwargs + + return config_dict + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + # 0. Copy origin config dict + original_dict = {k: v for k, v in config_dict.items()} + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove flax internal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + + # 2. Remove attributes that cannot be expected from expected config attributes + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + + if cls.has_compatibles: + compatible_classes = [ + c for c in cls._get_compatibles() if not isinstance(c, DummyObject) + ] + else: + compatible_classes = [] + + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = { + k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls + } + + # remove attributes from orig class that cannot be expected + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = { + k: v + for k, v in config_dict.items() + if k not in unexpected_keys_from_orig + } + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. 
Create keyword arguments that will be passed to __init__ from expected keyword arguments + init_dict = {} + for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + # 4. Give nice warning if unexpected values have been passed + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, " + "but are not expected and will be ignored. Please verify your " + f"{cls.config_name} configuration file." + ) + + # 5. Give nice info if config attributes are initiliazed to default because they have not been passed + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = { + k: v for k, v in original_dict.items() if k not in init_dict + } + + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + processor: Optional["AttnProcessor"] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.cross_attention_norm = cross_attention_norm + + self.scale = dim_head ** -0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=inner_dim, + num_groups=norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + self.group_norm = None + + if cross_attention_norm: + self.norm_cross = nn.LayerNorm(cross_attention_dim) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + if processor is None: + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") + else CrossAttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: bool, + attention_op: Optional[Callable] = None, + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) + ) + + if use_memory_efficient_attention_xformers: + if self.added_kv_proj_dim is not None: + # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when" + " `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + processor = LoRAXFormersCrossAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + else: + processor = XFormersCrossAttnProcessor(attention_op=attention_op) + else: + if is_lora: + processor = LoRACrossAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + else: + processor = CrossAttnProcessor() + + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." + ) + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = CrossAttnAddedKVProcessor() + else: + processor = CrossAttnProcessor() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + **cross_attention_kwargs, + ): + # The `CrossAttention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size * head_size, seq_len, dim // head_size + ) + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() 
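+        # The scores below are computed in one fused call:
+        #   scores = beta * baddbmm_input + alpha * (Q @ K^T)
+        # where baddbmm_input is the additive attention mask (beta=1) or an
+        # unused placeholder tensor (beta=0), and alpha is self.scale = dim_head ** -0.5.
+        # When upcast_softmax is set, the softmax runs in float32 and the resulting
+        # probabilities are cast back to the original query dtype afterwards.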
+ + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + message=( + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" + " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" + " `prepare_attention_mask` when preparing the attention_mask." + ), + ) + batch_size = 1 + + head_size = self.heads + if attention_mask is None: + return attention_mask + + if attention_mask.shape[-1] != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + return attention_mask + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + 
cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + 
in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not 
None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + # assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + # assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setitem__(name, value) + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + # if gate.device.type != "mps": + # return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ModelMixin(torch.nn.Module): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. + + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~models.ModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + + def __init__(self): + super().__init__() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any( + hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing + for m in self.modules() + ) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing." 
+ ) + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention( + self, attention_op: Optional[Callable] = None + ): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import UNet2DConditionModel + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> model = UNet2DConditionModel.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 + ... ) + >>> model = model.to("cuda") + >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~models.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. 
Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + if safe_serialization and not is_safetensors_available(): + raise ImportError( + "`safe_serialization` requires the `safetensors library: `pip install safetensors`." + ) + + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + + # Save the model + if safe_serialization: + safetensors.torch.save_file( + state_dict, + os.path.join(save_directory, weights_name), + metadata={"format": "pt"}, + ) + else: + torch.save(state_dict, os.path.join(save_directory, weights_name)) + + logger.info( + f"Model weights saved in {os.path.join(save_directory, weights_name)}" + ) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs + ): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. 
+ + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + # This variable will flag if we're loading a sharded checkpoint. 
In this case the archive file is just the + # Load model + + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import ( + load_flax_checkpoint_in_pytorch_model, + ) + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if is_safetensors_available(): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + except: # noqa: E722 + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set( + state_dict.keys() + ) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + for param_name, param in state_dict.items(): + accepts_dtype = "dtype" in set( + inspect.signature( + set_module_tensor_to_device + ).parameters.keys() + ) + if accepts_dtype: + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + dtype=torch_dtype, + ) + else: + set_module_tensor_to_device( + model, param_name, param_device, value=param + ) + else: # else let accelerate handle loading and dispatching. 
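+                        # A device_map was supplied: let accelerate load the checkpoint
+                        # and dispatch each submodule to the device assigned in device_map.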
+ # Load weights and dispatch according to the device_map + # by deafult the device_map is None and the weights are loaded on the CPU + accelerate.load_checkpoint_and_dispatch( + model, model_file, device_map, dtype=torch_dtype + ) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape + != model_state_dict[model_key].shape + ): + mismatched_keys.append( + ( + checkpoint_key, + state_dict[checkpoint_key].shape, + model_state_dict[model_key].shape, + ) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ raise RuntimeError( + f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}" + ) + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info( + f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" + ) + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters( + self, only_trainable: bool = False, exclude_embeddings: bool = False + ) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. 
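+
+        Example (an illustrative sketch only; `unet` is assumed to be an already
+        constructed instance of a model defined in this file, e.g. `UNet2DConditionModel`):
+
+        ```py
+        total = unet.num_parameters()
+        trainable = unet.num_parameters(only_trainable=True)
+        no_embeddings = unet.num_parameters(exclude_embeddings=True)
+        ```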
+ """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter + for name, parameter in self.named_parameters() + if name not in embedding_param_names + ] + return sum( + p.numel() + for p in non_embedding_parameters + if p.requires_grad or not only_trainable + ) + else: + return sum( + p.numel() + for p in self.parameters() + if p.requires_grad or not only_trainable + ) + + +class OptionalDependencyNotAvailable(BaseException): + """An error indicating that an optional dependency of Diffusers was not found in the environment.""" + + +class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. 
+ """ + + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", # default, scale_shift, ada_group + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + else: + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group": + self.time_emb_proj = None + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + else: + self.time_emb_proj = None + + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + else: + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = torch.nn.Conv2d( + out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1 + ) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D( + in_channels, use_conv=False, padding=1, name="op" + ) + + self.use_in_shortcut = ( + self.in_channels != conv_2d_out_channels + if use_in_shortcut is None + else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + if self.time_embedding_norm == 
"ada_group": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "gelu": + self.act = nn.GELU() + else: + raise ValueError( + f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" + ) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + elif post_act_fn == "silu": + self.post_act = nn.SiLU() + elif post_act_fn == "mish": + self.post_act = nn.Mish() + elif post_act_fn == "gelu": + self.post_act = nn.GELU() + else: + raise ValueError( + f"{post_act_fn} does not exist. 
Make sure to define one of 'silu', 'mish', or 'gelu'" + ) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__( + self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. 
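+
+        Example (an illustrative sketch for the continuous-input case; the parameter values
+        and tensor shapes are arbitrary assumptions for demonstration):
+
+        ```py
+        model = Transformer2DModel(num_attention_heads=8, attention_head_dim=4, in_channels=32)
+        x = torch.randn(1, 32, 16, 16)   # (batch, channels, height, width)
+        out = model(x).sample            # same shape as the input
+        ```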
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate( + "norm_type!=num_embeds_ada_norm", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif ( + not self.is_input_continuous + and not self.is_input_vectorized + and not self.is_input_patches + ): + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + assert ( + sample_size is not None + ), "Transformer2DModel over discrete input must provide sample_size" + assert ( + num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, + embed_dim=inner_dim, + height=self.height, + width=self.width, + ) + elif self.is_input_patches: + assert ( + sample_size is not None + ), "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continous projections + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear( + inner_dim, patch_size * patch_size * self.out_channels + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = ( + self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + ) + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height * self.patch_size, + width * self.patch_size, + ) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UNet2DConditionLoadersMixin: + def load_attn_procs( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Load pretrained attention processor layers into `UNet2DConditionModel`. 
Attention processor layers have to be + defined in + [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + and be a `torch.nn.Module` class. + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. 
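+
+        Example (an illustrative sketch; the path below is a hypothetical placeholder for a
+        local directory or hub repository containing LoRA attention-processor weights):
+
+        ```py
+        unet.load_attn_procs("path/to/lora_attn_procs")
+        ```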
+ + + """ + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join( + key.split(".")[-3:] + ) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRACrossAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + ) + attn_processors[key].load_state_dict(value_dict) + + else: + raise ValueError( + f"{model_file} does not seem to be in the correct format expected by LoRA training." + ) + + # set correct dtype & device + attn_processors = { + k: v.to(device=self.device, dtype=self.dtype) + for k, v in attn_processors.items() + } + + # set layers + self.set_attn_processor(attn_processors) + + def save_attn_procs( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + weights_name: str = LORA_WEIGHT_NAME, + save_function: Callable = None, + ): + r""" + Save an attention processor to a directory, so that it can be re-loaded using the + `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. 
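+
+        Example (an illustrative sketch; the directory name is a hypothetical placeholder):
+
+        ```py
+        unet.save_attn_procs("./attn_procs")
+        # the saved weights can later be restored with `unet.load_attn_procs("./attn_procs")`
+        ```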
+ """ + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + if save_function is None: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = AttnProcsLayers(self.attn_processors) + + # Save the model + state_dict = model_to_save.state_dict() + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "") + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and is_main_process + ): + os.remove(full_filename) + + # Save the model + save_function(state_dict, os.path.join(save_directory, weights_name)) + + logger.info( + f"Model weights saved in {os.path.join(save_directory, weights_name)}" + ) + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the + mid block layer if `None`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 
+ If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately + summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. 
`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
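+            # (Illustrative note: with `class_embed_type="projection"`, `class_labels` is expected
+            # to be a float tensor of shape `(batch, projection_class_embeddings_input_dim)`; the
+            # `TimestepEmbedding` below projects it directly to `(batch, time_embed_dim)` with no
+            # sinusoidal step in between.)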
+ self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, 
+ resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor] + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor( + self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]] + ): + r""" + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. 
If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D) + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2 ** self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. 
down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. 
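+        # `down_block_res_samples` is consumed back-to-front above, so each up block
+        # receives the skip activations produced by its mirror-image down block.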
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + 
in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class Upsample2D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + # assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + # if dtype == torch.bfloat16: + # hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + # if hidden_states.shape[0] >= 64: + # hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=2.0, mode="nearest" + ) + else: + hidden_states = F.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + + # If the input is bfloat16, we cast back to bfloat16 + # if dtype == torch.bfloat16: + # hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _get_default_logging_level(): + """ + If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def compare_versions( + library_or_version: Union[str, Version], operation: str, requirement_version: str +): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. 
+ requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError( + f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}" + ) + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock2D" + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" + ) + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. 
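+
+    Example (usage sketch):
+
+        logger = get_logger(__name__)
+        logger.warning("routed through the library root logger configured above")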
+ """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D" + ) + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, 
+ prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" + ) + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + 
resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +def interpolate( + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: F811 + r"""Down/up samples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + The algorithm used for interpolation is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric sampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple, + its length has to match the number of spatial dimensions; `input.dim() - 2`. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. Default: ``None``. + antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias + option together with ``align_corners=False``, interpolation result would match Pillow + result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``. + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. 
note:: + Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation + algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep + backward compatibility. + Mode ``mode='nearest'`` matches buggy OpenCV's ``INTER_NEAREST`` interpolation algorithm. + + Note: + {backward_reproducibility_note} + """ + if has_torch_function_unary(input): + return handle_torch_function( + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + antialias=antialias, + ) + + if mode in ("nearest", "area", "nearest-exact"): + if align_corners is not None: + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) + else: + if align_corners is None: + align_corners = False + + dim = input.dim() - 2 # Number of spatial dimensions. + + # Process size and scale_factor. Validate that exactly one is set. + # Validate its length if it is a list, or expand it if it is a scalar. + # After this block, exactly one of output_size and scale_factors will + # be non-None, and it will be a list (or tuple). + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + elif size is not None: + assert scale_factor is None + scale_factors = None + if isinstance(size, (list, tuple)): + if len(size) != dim: + raise ValueError( + "Input and output must have the same number of spatial dimensions, but got " + f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "output size in (o1, o2, ...,oK) format." + ) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError( + "Input and scale_factor must have the same number of spatial dimensions, but " + f"got input with spatial dimensions of {list(input.shape[2:])} and " + f"scale_factor of shape {scale_factor}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "scale_factor in (s1, s2, ...,sK) format." + ) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError("either size or scale_factor should be defined") + + if ( + recompute_scale_factor is not None + and recompute_scale_factor + and size is not None + ): + raise ValueError( + "recompute_scale_factor is not meaningful with an explicit size." + ) + + # "area" mode always requires an explicit size rather than scale factor. + # Re-use the recompute_scale_factor code path. + if mode == "area" and output_size is None: + recompute_scale_factor = True + + if recompute_scale_factor is not None and recompute_scale_factor: + # We compute output_size here, then un-set scale_factors. + # The C++ code will recompute it based on the (integer) output size. 
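+        # e.g. a 4-D input with spatial size (16, 16) and scale_factor=2.0 yields
+        # output_size == [32, 32] here, after which scale_factors is set to None.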
+ assert scale_factors is not None + if not torch.jit.is_scripting() and torch._C._get_tracing_state(): + # make scale_factor a tensor in tracing so constant doesn't get baked in + output_size = [ + ( + torch.floor( + ( + input.size(i + 2).float() + * torch.tensor(scale_factors[i], dtype=torch.float32) + ).float() + ) + ) + for i in range(dim) + ] + elif torch.jit.is_scripting(): + output_size = [ + int(math.floor(float(input.size(i + 2)) * scale_factors[i])) + for i in range(dim) + ] + else: + output_size = [ + _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim) + ] + scale_factors = None + + if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + raise ValueError( + "Anti-alias option is only supported for bilinear and bicubic modes" + ) + + if input.dim() == 3 and mode == "nearest": + return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest": + return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest": + return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "area": + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == "area": + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == "area": + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) + + if input.dim() == 3 and mode == "linear": + assert align_corners is not None + return torch._C._nn.upsample_linear1d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bilinear": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bilinear2d_aa( + input, output_size, align_corners, scale_factors + ) + return torch._C._nn.upsample_bilinear2d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 5 and mode == "trilinear": + assert align_corners is not None + return torch._C._nn.upsample_trilinear3d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bicubic": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bicubic2d_aa( + input, output_size, align_corners, scale_factors + ) + return torch._C._nn.upsample_bicubic2d( + input, output_size, align_corners, scale_factors + ) + + if input.dim() == 3 and mode == "bilinear": + raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if input.dim() == 3 and mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if input.dim() == 4 and mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if input.dim() == 4 and mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if input.dim() == 5 and mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if input.dim() == 5 and mode == "bilinear": + raise 
NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact" + " (got {})".format(input.dim(), mode) + ) + + +def is_accelerate_available(): + return _accelerate_available + + +def is_flax_available(): + return _flax_available + + +def is_k_diffusion_available(): + return _k_diffusion_available + + +def is_librosa_available(): + return _librosa_available + + +def is_onnx_available(): + return _onnx_available + + +def is_safetensors_available(): + return _safetensors_available + + +def is_scipy_available(): + return _scipy_available + + +def is_tensor(obj): + r"""Returns True if `obj` is a PyTorch tensor. + + Note that this function is simply doing ``isinstance(obj, Tensor)``. + Using that ``isinstance`` check is better for typechecking with mypy, + and more explicit - so it's recommended to use that instead of + ``is_tensor``. + + Args: + obj (Object): Object to test + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.is_tensor(x) + True + + """ + return isinstance(obj, torch.Tensor) + + +def is_torch_available(): + return _torch_available + + +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_available(): + return _transformers_available + + +def is_xformers_available(): + return _xformers_available diff --git a/pi/models/unet/prologue.py b/pi/models/unet/prologue.py new file mode 100644 index 0000000..ba291fc --- /dev/null +++ b/pi/models/unet/prologue.py @@ -0,0 +1,10 @@ +import os + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES diff --git a/requirements.txt b/requirements.txt index 8ec3a01..a9b60f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ typeguard==3.0.0b2 multiprocess pybind11 pytest +pyccolo \ No newline at end of file