diff --git a/README.md b/README.md
index 7a101c1..3b55d81 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ Early days of a lightweight MLIR Python frontend with support for PyTorch (throu
Just
```shell
-pip install - requirements.txt
+pip install -r requirements.txt
pip install . --no-build-isolation
```
diff --git a/examples/unet.py b/examples/unet.py
index ca51b1d..5d45508 100644
--- a/examples/unet.py
+++ b/examples/unet.py
@@ -1,27 +1,92 @@
-import pi
-from pi import nn
-from pi.mlir.utils import pipile
-from pi.utils.annotations import annotate_args
-from pi.models.unet import UNet2DConditionModel
-
-
-class MyUNet(nn.Module):
- def __init__(self):
- super().__init__()
- self.unet = UNet2DConditionModel()
-
- @annotate_args(
- [
- None,
- ([-1, -1, -1, -1], pi.float32, True),
- ]
+import inspect
+import re
+
+import numpy as np
+import torch
+
+from pi.lazy_importer.run_lazy_imports import do_package_imports, do_hand_imports
+from pi.lazy_importer import lazy_imports
+
+
+#
+
+
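+# Build a float tensor of the given shape filled with uniform random values scaled by
+# `scale`; the rng and name parameters are accepted but currently unused.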
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+ values = []
+ for _ in range(total_dims):
+ values.append(np.random.random() * scale)
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
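+# Instantiate the given UNet2DConditionModel-style constructor with a tiny config and
+# run a single forward pass on random noise, timestep, and encoder hidden states.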
+def run(
+ CTor,
+ down_block_types=("CrossAttnDownBlock2D", "ResnetDownsampleBlock2D"),
+ up_block_types=("UpBlock2D", "ResnetUpsampleBlock2D"),
+):
+ unet = CTor(
+ **{
+ "block_out_channels": (32, 64),
+ "down_block_types": down_block_types,
+ "up_block_types": up_block_types,
+ "cross_attention_dim": 32,
+ "attention_head_dim": 8,
+ "out_channels": 4,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "sample_size": 32,
+ }
)
- def forward(self, x):
- y = self.resnet(x)
- return y
+ unet.eval()
+ batch_size = 4
+ num_channels = 4
+ sizes = (32, 32)
+
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ time_step = torch.tensor([10])
+ encoder_hidden_states = floats_tensor((batch_size, 4, 32))
+ output = unet(noise, time_step, encoder_hidden_states)
+
+
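+# Trace two diffusers UNet2DConditionModel configurations under the lazy-import tracer,
+# capturing the source of matching classes and functions, and emit unet_linearized.py
+# via do_package_imports.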
+def make_linearized():
+ def filter(ret):
+ try:
+ MODULE_TARGET = lambda x: re.match(
+ r"(huggingface|torch|diffusers)", inspect.getmodule(x).__package__
+ )
+ return MODULE_TARGET(ret)
+        except Exception:
+            return None
+
+ lazy_imports.MODULE_TARGET = filter
+
+ def _inner():
+
+ from diffusers import UNet2DConditionModel
+
+ run(
+ UNet2DConditionModel,
+ down_block_types=("CrossAttnDownBlock2D", "ResnetDownsampleBlock2D"),
+ up_block_types=("UpBlock2D", "ResnetUpsampleBlock2D"),
+ )
+ run(
+ UNet2DConditionModel,
+ down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
+ )
+
+ prefix = "from pi.models.unet.prologue import CONFIG_NAME, LORA_WEIGHT_NAME"
+ name = "unet_linearized"
+ do_package_imports(_inner, prefix, name)
+
+
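+# Sanity check: run the same forward pass against the linearized module (assumed to be
+# the generated file installed as pi.models.unet.linearized).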
+def run_linearized():
+ from pi.models.unet import linearized
+
+ run(linearized.UNet2DConditionModel)
-test_module = MyUNet()
-x = pi.randn((1, 3, 64, 64))
-mlir_module = pipile(test_module, example_args=(x,))
-print(mlir_module)
+if __name__ == "__main__":
+ make_linearized()
diff --git a/pi/lazy_importer/__init__.py b/pi/lazy_importer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pi/lazy_importer/lazy_imports.py b/pi/lazy_importer/lazy_imports.py
new file mode 100644
index 0000000..216843d
--- /dev/null
+++ b/pi/lazy_importer/lazy_imports.py
@@ -0,0 +1,377 @@
+# -*- coding: utf-8 -*-
+import ast
+import copy
+import importlib
+import inspect
+import logging
+import re
+import sys
+from textwrap import dedent
+from types import FrameType
+from typing import Any, List, Optional, Set, Union
+
+import pyccolo as pyc
+import torch
+from pyccolo._fast.misc_ast_utils import subscript_to_slice
+
+logger = logging.getLogger(__name__)
+
+
+_unresolved = object()
+
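+# Accumulators populated during tracing; their contents are written out by run_lazy_imports.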
+imports_file = set()
+import_uses_file = set()
+nn_modules_file = set()
+functions_file = set()
+classes_file = set()
+all_calls = set()
+
+
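+# Stand-in bound into module globals in place of an imported name; the real module or
+# attribute is resolved on first call or attribute access.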
+class _LazySymbol:
+ non_modules: Set[str] = set()
+ blocklist_packages: Set[str] = set()
+
+ def __init__(self, spec: Union[ast.Import, ast.ImportFrom]):
+ self.spec = spec
+ imports_file.add(ast.unparse(spec))
+ self.value = _unresolved
+
+ @property
+ def qualified_module(self) -> str:
+ node = self.spec
+ name = node.names[0].name
+ if isinstance(node, ast.Import):
+ return name
+ else:
+ return f"{node.module}.{name}"
+
+ @staticmethod
+ def top_level_package(module: str) -> str:
+ return module.split(".", 1)[0]
+
+ @classmethod
+ def _unwrap_module(cls, module: str) -> Any:
+ if module in sys.modules:
+ return sys.modules[module]
+ exc = None
+ if module not in cls.non_modules:
+ try:
+ with pyc.allow_reentrant_event_handling():
+ return importlib.import_module(module)
+ except ImportError as e:
+ cls.non_modules.add(module)
+ exc = e
+ except Exception:
+ logger.error("fatal error trying to import module %s", module)
+ raise
+ module_symbol = module.rsplit(".", 1)
+ if len(module_symbol) != 2:
+ raise exc
+ else:
+ module, symbol = module_symbol
+ ret = getattr(cls._unwrap_module(module), symbol)
+ if isinstance(ret, _LazySymbol):
+ ret = ret.unwrap()
+ handle_unwrapped_ret(ret)
+ return ret
+
+ def _unwrap_helper(self) -> Any:
+ return self._unwrap_module(self.qualified_module)
+
+ def unwrap(self) -> Any:
+ if self.value is not _unresolved:
+ return self.value
+ ret = self._unwrap_helper()
+ handle_unwrapped_ret(ret)
+ self.value = ret
+ return ret
+
+ def __call__(self, *args, **kwargs):
+ return self.unwrap()(*args, **kwargs)
+
+ def __getattr__(self, item):
+        # Resolve the lazy import and forward the attribute lookup.
+        return getattr(self.unwrap(), item)
+
+
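+# Static AST pass collecting the names bound by import statements; a star-import sets
+# lazy_names to None, meaning every name is treated as potentially lazy.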
+class _GetLazyNames(ast.NodeVisitor):
+ def __init__(self) -> None:
+ self.lazy_names: Optional[Set[str]] = set()
+
+ def visit_Import(self, node: ast.Import) -> None:
+ if self.lazy_names is None:
+ return
+ for alias in node.names:
+ if alias.asname is None:
+ return
+ for alias in node.names:
+ self.lazy_names.add(alias.asname)
+
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
+ if self.lazy_names is None:
+ return
+ for alias in node.names:
+ if alias.name == "*":
+ self.lazy_names = None
+ return
+ for alias in node.names:
+ self.lazy_names.add(alias.asname or alias.name)
+
+ @classmethod
+ def compute(cls, node: ast.Module) -> Set[str]:
+ inst = cls()
+ inst.visit(node)
+ return inst.lazy_names
+
+
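+# Derive stable guard-variable names for attribute/subscript chains so pyccolo can skip
+# sites once they have been resolved.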
+def _make_attr_guard_helper(node: ast.Attribute) -> Optional[str]:
+ if isinstance(node.value, ast.Name):
+ return f"{node.value.id}_O_{node.attr}"
+ elif isinstance(node.value, ast.Attribute):
+ prefix = _make_attr_guard_helper(node.value)
+ if prefix is None:
+ return None
+ else:
+ return f"{prefix}_O_{node.attr}"
+ else:
+ return None
+
+
+def _make_attr_guard(node: ast.Attribute) -> Optional[str]:
+ suffix = _make_attr_guard_helper(node)
+ if suffix is None:
+ return None
+ else:
+ return f"_Xix_{suffix}"
+
+
+def _make_subscript_guard_helper(node: ast.Subscript) -> Optional[str]:
+ slice_val = subscript_to_slice(node)
+ if isinstance(slice_val, (ast.Constant, ast.Str, ast.Num, ast.Name)):
+ if isinstance(slice_val, ast.Name):
+ subscript = slice_val.id
+ elif hasattr(slice_val, "s"):
+ subscript = f"_{slice_val.s}_" # type: ignore
+ elif hasattr(slice_val, "n"):
+ subscript = f"_{slice_val.n}_" # type: ignore
+ else:
+ return None
+ else:
+ return None
+ if isinstance(node.value, ast.Name):
+ return f"{node.value.id}_S_{subscript}"
+ elif isinstance(node.value, ast.Subscript):
+ prefix = _make_subscript_guard_helper(node.value)
+ if prefix is None:
+ return None
+ else:
+ return f"{prefix}_S_{subscript}"
+ else:
+ return None
+
+
+def _make_subscript_guard(node: ast.Subscript) -> Optional[str]:
+ suffix = _make_subscript_guard_helper(node)
+ if suffix is None:
+ return None
+ else:
+ return f"_Xix_{suffix}"
+
+
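+# Filter callable installed by the caller (e.g. examples/unet.py); given an object, it
+# returns a regex match on the owning package or None.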
+MODULE_TARGET = None
+
+
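+# Record the source of classes resolved through lazy imports, separating non-torch
+# nn.Module subclasses from plain classes.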
+def handle_unwrapped_ret(ret):
+ if match := MODULE_TARGET(ret):
+ if (
+ inspect.isclass(ret)
+ and issubclass(ret, torch.nn.Module)
+ and match.group() != "torch"
+ ):
+ src = dedent(inspect.getsource(ret))
+ nn_modules_file.add(src)
+ elif inspect.isclass(ret):
+ src = dedent(inspect.getsource(ret))
+ classes_file.add(src)
+
+
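+# Record the source of nn.Module instances, classes, and functions observed at call
+# sites (skipping __init__ to avoid capturing super().__init__ calls).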
+def handle_call(ret):
+ if match := MODULE_TARGET(ret):
+ if isinstance(ret, torch.nn.Module) and match.group() != "torch":
+ src = dedent(inspect.getsource(ret.__class__))
+ nn_modules_file.add(src)
+ elif inspect.isclass(ret) and match.group() != "torch":
+ src = dedent(inspect.getsource(ret))
+ classes_file.add(src)
+ elif inspect.isfunction(ret):
+ if hasattr(ret, "__name__") and ret.__name__ == "__init__":
+ # 'super().__init__()'
+ return
+ src = dedent(inspect.getsource(ret))
+ functions_file.add(src)
+
+
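+# pyccolo tracer: rewrites top-level imports into _LazySymbol placeholders and resolves
+# them lazily on the first name, attribute, or subscript access.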
+class LazyImportTracer(pyc.BaseTracer):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.cur_module_lazy_names: Set[str] = set()
+ self.saved_attributes: List[Any] = []
+ self.saved_subscripts: List[Any] = []
+ self.saved_slices: List[Any] = []
+
+ def _is_name_lazy_load(self, node: Union[ast.Attribute, ast.Name]) -> bool:
+ if self.cur_module_lazy_names is None:
+ return True
+ elif isinstance(node, ast.Name):
+ return node.id in self.cur_module_lazy_names
+ elif isinstance(node, (ast.Attribute, ast.Subscript)):
+ return self._is_name_lazy_load(node.value) # type: ignore
+ elif isinstance(node, ast.Call):
+ return self._is_name_lazy_load(node.func)
+ else:
+ return False
+
+ def static_init_module(self, node: ast.Module) -> None:
+ self.cur_module_lazy_names = _GetLazyNames.compute(node)
+
+ @staticmethod
+ def _convert_relative_to_absolute(
+ package: str, module: Optional[str], level: int
+ ) -> str:
+ prefix = package.rsplit(".", level - 1)[0]
+ if not module:
+ return prefix
+ else:
+ return f"{prefix}.{module}"
+
+ @pyc.init_module
+ def init_module(
+ self, _ret: None, node: ast.Module, frame: FrameType, *_, **__
+ ) -> None:
+ assert node is not None
+ for guard in self.local_guards_by_module_id.get(id(node), []):
+ frame.f_globals[guard] = False
+
+ @pyc.before_call()
+ def before_call(
+ self,
+ ret: None,
+ node: Union[ast.Import, ast.ImportFrom],
+ frame: FrameType,
+ *_,
+ **__,
+ ) -> Any:
+ handle_call(ret)
+ return ret
+
+ @pyc.before_stmt(
+ when=pyc.Predicate(
+ lambda node: isinstance(node, (ast.Import, ast.ImportFrom))
+ and pyc.is_outer_stmt(node),
+ static=True,
+ )
+ )
+ def before_stmt(
+ self,
+ _ret: None,
+ node: Union[ast.Import, ast.ImportFrom],
+ frame: FrameType,
+ *_,
+ **__,
+ ) -> Any:
+ is_import = isinstance(node, ast.Import)
+ for alias in node.names:
+ if alias.name == "*":
+ return None
+ elif is_import and alias.asname is None:
+ return None
+ package = frame.f_globals["__package__"]
+ level = getattr(node, "level", 0)
+ if is_import:
+ module = None
+ else:
+ module = node.module # type: ignore
+ if level > 0:
+ module = self._convert_relative_to_absolute(package, module, level)
+ for alias in node.names:
+ node_cpy = copy.deepcopy(node)
+ node_cpy.names = [alias]
+ if module is not None:
+ node_cpy.module = module # type: ignore
+ node_cpy.level = 0 # type: ignore
+ lz = _LazySymbol(spec=node_cpy)
+ frame.f_globals[alias.asname or alias.name] = lz
+ import_uses_file.add(alias.asname or alias.name)
+ return pyc.Pass
+
+ @pyc.before_attribute_load(
+ when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard
+ )
+ def before_attr_load(self, ret: Any, *_, **__) -> Any:
+ self.saved_attributes.append(ret)
+ return ret
+
+ @pyc.after_attribute_load(
+ when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard
+ )
+ def after_attr_load(
+ self, ret: Any, node: ast.Attribute, frame: FrameType, _evt, guard, *_, **__
+ ) -> Any:
+ if guard is not None:
+ frame.f_globals[guard] = True
+ saved_attr_obj = self.saved_attributes.pop()
+ if isinstance(ret, _LazySymbol):
+ ret = ret.unwrap()
+ setattr(saved_attr_obj, node.attr, ret)
+ return pyc.Null if ret is None else ret
+
+ @pyc.before_subscript_load(
+ when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard
+ )
+ def before_subscript_load(self, ret: Any, *_, attr_or_subscript: Any, **__) -> Any:
+ self.saved_subscripts.append(ret)
+ self.saved_slices.append(attr_or_subscript)
+ return ret
+
+ @pyc.after_subscript_load(
+ when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard
+ )
+ def after_subscript_load(
+ self, ret: Any, _node, frame: FrameType, _evt, guard, *_, **__
+ ) -> Any:
+ if guard is not None:
+ frame.f_globals[guard] = True
+ saved_subscript_obj = self.saved_subscripts.pop()
+ saved_slice_obj = self.saved_slices.pop()
+ if isinstance(ret, _LazySymbol):
+ ret = ret.unwrap()
+ saved_subscript_obj[saved_slice_obj] = ret
+ return pyc.Null if ret is None else ret
+
+ @pyc.load_name(
+ when=pyc.Predicate(_is_name_lazy_load, static=True),
+ guard=lambda node: f"_Xix_{node.id}_guard",
+ )
+ def load_name(
+ self, ret: Any, node: ast.Name, frame: FrameType, _evt, guard, *_, **__
+ ) -> Any:
+ if guard is not None:
+ frame.f_globals[guard] = True
+ if isinstance(ret, _LazySymbol):
+ ret = ret.unwrap()
+ frame.f_globals[node.id] = ret
+ else:
+ print("notlazysymbol", ret)
+ return pyc.Null if ret is None else ret
+
+
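+# Hook intended for sys.settrace/sys.setprofile: records the source of top-level
+# callables coming from matching packages.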
+def capture_all_calls(frame: FrameType, event, arg):
+ if event != "call":
+ return
+
+ try:
+ if MODULE_TARGET(frame.f_code):
+ code = inspect.getsource(frame.f_code)
+ if not re.match(r"\s", code):
+ all_calls.add(code)
+    except Exception:
+        pass
diff --git a/pi/lazy_importer/run_lazy_imports.py b/pi/lazy_importer/run_lazy_imports.py
new file mode 100644
index 0000000..a85966b
--- /dev/null
+++ b/pi/lazy_importer/run_lazy_imports.py
@@ -0,0 +1,76 @@
+import sys
+from textwrap import dedent
+
+from pi.lazy_importer.lazy_imports import (
+ LazyImportTracer,
+ imports_file,
+ import_uses_file,
+ nn_modules_file,
+ classes_file,
+ capture_all_calls,
+ functions_file,
+)
+
+
+class LazyImports(LazyImportTracer):
+ def should_instrument_file(self, filename: str) -> bool:
+ return not filename.endswith("lazy_imports.py")
+
+
+# /home/mlevental/mambaforge/envs/PI/lib/python3.11/site-packages/pyccolo/ast_rewriter.py:161
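+# Import header written at the top of every generated module so the captured sources
+# are self-contained.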
+PREFIX = """\
+from __future__ import annotations
+import functools
+import importlib
+import inspect
+import json
+import logging
+import math
+import os
+from collections import defaultdict
+from dataclasses import dataclass, fields
+from functools import partial
+from pathlib import PosixPath
+from typing import Optional, Callable, Tuple, Union, Dict, List, Any, OrderedDict
+import sys
+import warnings
+
+import numpy as np
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+logger = logging.getLogger(__name__)
+
+"""
+
+
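+# Trace `closure` to record its imports, write an imports.py that eagerly exercises
+# each imported name, then run `closure2` under the tracer.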
+def do_hand_imports(closure, closure2):
+ with LazyImports.instance():
+ closure()
+ with open("imports.py", "w") as f:
+ for n in sorted(imports_file):
+ print(n, file=f, flush=True)
+ for n in sorted(import_uses_file):
+ print(
+ dedent(
+ f"""\
+ try: {n}()
+ except: pass
+ """
+ ),
+ file=f,
+ flush=True,
+ )
+ with LazyImports.instance():
+ closure2()
+
+
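+# Trace `closure`, then dump PREFIX, the caller-supplied prefix line, and every captured
+# function/class/nn.Module source into {name}.py.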
+def do_package_imports(closure, prefix, name):
+ with LazyImports.instance():
+ closure()
+
+ with open(f"{name}.py", "w") as f:
+ print(PREFIX, file=f, flush=True)
+ print(prefix, file=f, flush=True)
+ for n in sorted(functions_file | classes_file | nn_modules_file):
+ print(n, file=f, flush=True)
diff --git a/pi/models/unet/__init__.py b/pi/models/unet/__init__.py
deleted file mode 100644
index 1d34370..0000000
--- a/pi/models/unet/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .unet_2d_condition import UNet2DConditionModel
\ No newline at end of file
diff --git a/pi/models/unet/attention.py b/pi/models/unet/attention.py
deleted file mode 100644
index 1bd0215..0000000
--- a/pi/models/unet/attention.py
+++ /dev/null
@@ -1,519 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from typing import Callable, Optional
-
-from ... import nn
-from ... import pi
-from ...nn import functional as F
-
-from .cross_attention import CrossAttention
-from .embeddings import CombinedTimestepLabelEmbeddings
-
-
-xformers = None
-
-
-class AttentionBlock(nn.Module):
- """
- An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
- to the N-d case.
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
- Uses three q, k, v linear layers to compute attention.
-
- Parameters:
- channels (`int`): The number of channels in the input and output.
- num_head_channels (`int`, *optional*):
- The number of channels in each head. If None, then `num_heads` = 1.
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
- rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
- eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
- """
-
- # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
-
- def __init__(
- self,
- channels: int,
- num_head_channels: Optional[int] = None,
- norm_num_groups: int = 32,
- rescale_output_factor: float = 1.0,
- eps: float = 1e-5,
- ):
- super().__init__()
- self.channels = channels
-
- self.num_heads = (
- channels // num_head_channels if num_head_channels is not None else 1
- )
- self.num_head_size = num_head_channels
- self.group_norm = nn.GroupNorm(
- num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True
- )
-
- # define q,k,v as linear layers
- self.query = nn.Linear(channels, channels)
- self.key = nn.Linear(channels, channels)
- self.value = nn.Linear(channels, channels)
-
- self.rescale_output_factor = rescale_output_factor
- self.proj_attn = nn.Linear(channels, channels, 1)
-
- self._use_memory_efficient_attention_xformers = False
- self._attention_op = None
-
- def reshape_heads_to_batch_dim(self, tensor):
- batch_size, seq_len, dim = tensor.shape
- head_size = self.num_heads
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
- tensor = tensor.permute(0, 2, 1, 3).reshape(
- batch_size * head_size, seq_len, dim // head_size
- )
- return tensor
-
- def reshape_batch_dim_to_heads(self, tensor):
- batch_size, seq_len, dim = tensor.shape
- head_size = self.num_heads
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
- tensor = tensor.permute(0, 2, 1, 3).reshape(
- batch_size // head_size, seq_len, dim * head_size
- )
- return tensor
-
- def set_use_memory_efficient_attention_xformers(
- self,
- use_memory_efficient_attention_xformers: bool,
- attention_op: Optional[Callable] = None,
- ):
- # if use_memory_efficient_attention_xformers:
- # if not is_xformers_available():
- # raise ModuleNotFoundError(
- # (
- # "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
- # " xformers"
- # ),
- # name="xformers",
- # )
- # elif not pi.cuda.is_available():
- # raise ValueError(
- # "pi.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
- # " only available for GPU "
- # )
- # else:
- # try:
- # # Make sure we can run the memory efficient attention
- # _ = xformers.ops.memory_efficient_attention(
- # pi.randn((1, 2, 40), device="cuda"),
- # pi.randn((1, 2, 40), device="cuda"),
- # pi.randn((1, 2, 40), device="cuda"),
- # )
- # except Exception as e:
- # raise e
- # self._use_memory_efficient_attention_xformers = (
- # use_memory_efficient_attention_xformers
- # )
- # self._attention_op = attention_op
- raise NotImplementedError
-
- def forward(self, hidden_states):
- residual = hidden_states
- batch, channel, height, width = hidden_states.shape
-
- # norm
- hidden_states = self.group_norm(hidden_states)
-
- hidden_states = hidden_states.view(batch, channel, height * width).transpose(
- 1, 2
- )
-
- # proj to q, k, v
- query_proj = self.query(hidden_states)
- key_proj = self.key(hidden_states)
- value_proj = self.value(hidden_states)
-
- scale = 1 / math.sqrt(self.channels / self.num_heads)
-
- query_proj = self.reshape_heads_to_batch_dim(query_proj)
- key_proj = self.reshape_heads_to_batch_dim(key_proj)
- value_proj = self.reshape_heads_to_batch_dim(value_proj)
-
- if self._use_memory_efficient_attention_xformers:
- # Memory efficient attention
- hidden_states = xformers.ops.memory_efficient_attention(
- query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
- )
- hidden_states = hidden_states.to(query_proj.dtype)
- else:
- attention_scores = pi.baddbmm(
- pi.empty(
- query_proj.shape[0],
- query_proj.shape[1],
- key_proj.shape[1],
- dtype=query_proj.dtype,
- device=query_proj.device,
- ),
- query_proj,
- key_proj.transpose(-1, -2),
- beta=0,
- alpha=scale,
- )
- attention_probs = pi.softmax(attention_scores.float(), dim=-1).type(
- attention_scores.dtype
- )
- hidden_states = pi.bmm(attention_probs, value_proj)
-
- # reshape hidden_states
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-
- # compute next hidden_states
- hidden_states = self.proj_attn(hidden_states)
-
- hidden_states = hidden_states.transpose(-1, -2).reshape(
- batch, channel, height, width
- )
-
- # res connect and rescale
- hidden_states = (hidden_states + residual) / self.rescale_output_factor
- return hidden_states
-
-
-class BasicTransformerBlock(nn.Module):
- r"""
- A basic Transformer block.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- num_embeds_ada_norm (:
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
- attention_bias (:
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- dropout=0.0,
- cross_attention_dim: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- attention_bias: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm",
- final_dropout: bool = False,
- ):
- super().__init__()
- self.only_cross_attention = only_cross_attention
-
- self.use_ada_layer_norm_zero = (
- num_embeds_ada_norm is not None
- ) and norm_type == "ada_norm_zero"
- self.use_ada_layer_norm = (
- num_embeds_ada_norm is not None
- ) and norm_type == "ada_norm"
-
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
- raise ValueError(
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
- )
-
- # 1. Self-Attn
- self.attn1 = CrossAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- )
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None:
- self.attn2 = CrossAttention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- ) # is self-attn if encoder_hidden_states is none
- else:
- self.attn2 = None
-
- if self.use_ada_layer_norm:
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
- elif self.use_ada_layer_norm_zero:
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
- else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
- if cross_attention_dim is not None:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
- # the second cross attention block.
- self.norm2 = (
- AdaLayerNorm(dim, num_embeds_ada_norm)
- if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- )
- else:
- self.norm2 = None
-
- # 3. Feed-forward
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
- def forward(
- self,
- hidden_states,
- encoder_hidden_states=None,
- timestep=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- class_labels=None,
- ):
- if self.use_ada_layer_norm:
- norm_hidden_states = self.norm1(hidden_states, timestep)
- elif self.use_ada_layer_norm_zero:
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- else:
- norm_hidden_states = self.norm1(hidden_states)
-
- # 1. Self-Attention
- cross_attention_kwargs = (
- cross_attention_kwargs if cross_attention_kwargs is not None else {}
- )
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states
- if self.only_cross_attention
- else None,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
- if self.use_ada_layer_norm_zero:
- attn_output = gate_msa.unsqueeze(1) * attn_output
- hidden_states = attn_output + hidden_states
-
- if self.attn2 is not None:
- norm_hidden_states = (
- self.norm2(hidden_states, timestep)
- if self.use_ada_layer_norm
- else self.norm2(hidden_states)
- )
-
- # 2. Cross-Attention
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
- hidden_states = attn_output + hidden_states
-
- # 3. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self.use_ada_layer_norm_zero:
- norm_hidden_states = (
- norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
- )
-
- ff_output = self.ff(norm_hidden_states)
-
- if self.use_ada_layer_norm_zero:
- ff_output = gate_mlp.unsqueeze(1) * ff_output
-
- hidden_states = ff_output + hidden_states
-
- return hidden_states
-
-
-class FeedForward(nn.Module):
- r"""
- A feed-forward layer.
-
- Parameters:
- dim (`int`): The number of channels in the input.
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
- """
-
- def __init__(
- self,
- dim: int,
- dim_out: Optional[int] = None,
- mult: int = 4,
- dropout: float = 0.0,
- activation_fn: str = "geglu",
- final_dropout: bool = False,
- ):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = dim_out if dim_out is not None else dim
-
- if activation_fn == "gelu":
- act_fn = GELU(dim, inner_dim)
- if activation_fn == "gelu-approximate":
- act_fn = GELU(dim, inner_dim, approximate="tanh")
- elif activation_fn == "geglu":
- act_fn = GEGLU(dim, inner_dim)
- elif activation_fn == "geglu-approximate":
- act_fn = ApproximateGELU(dim, inner_dim)
-
- self.net = nn.ModuleList([])
- # project in
- self.net.append(act_fn)
- # project dropout
- self.net.append(nn.Dropout(dropout))
- # project out
- self.net.append(nn.Linear(inner_dim, dim_out))
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
- if final_dropout:
- self.net.append(nn.Dropout(dropout))
-
- def forward(self, hidden_states):
- for module in self.net:
- hidden_states = module(hidden_states)
- return hidden_states
-
-
-class GELU(nn.Module):
- r"""
- GELU activation function with tanh approximation support with `approximate="tanh"`.
- """
-
- def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
- self.approximate = approximate
-
- def gelu(self, gate):
- if gate.device.type != "mps":
- return F.gelu(gate, approximate=self.approximate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=pi.float32), approximate=self.approximate).to(
- dtype=gate.dtype
- )
-
- def forward(self, hidden_states):
- hidden_states = self.proj(hidden_states)
- hidden_states = self.gelu(hidden_states)
- return hidden_states
-
-
-class GEGLU(nn.Module):
- r"""
- A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
- Parameters:
- dim_in (`int`): The number of channels in the input.
- dim_out (`int`): The number of channels in the output.
- """
-
- def __init__(self, dim_in: int, dim_out: int):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def gelu(self, gate):
- if gate.device.type != "mps":
- return F.gelu(gate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=pi.float32)).to(dtype=gate.dtype)
-
- def forward(self, hidden_states):
- hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
- return hidden_states * self.gelu(gate)
-
-
-class ApproximateGELU(nn.Module):
- """
- The approximate form of Gaussian Error Linear Unit (GELU)
-
- For more details, see section 2: https://arxiv.org/abs/1606.08415
- """
-
- def __init__(self, dim_in: int, dim_out: int):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
-
- def forward(self, x):
- x = self.proj(x)
- return x * pi.sigmoid(1.702 * x)
-
-
-class AdaLayerNorm(nn.Module):
- """
- Norm layer modified to incorporate timestep embeddings.
- """
-
- def __init__(self, embedding_dim, num_embeddings):
- super().__init__()
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
- self.silu = nn.SiLU()
- self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
-
- def forward(self, x, timestep):
- emb = self.linear(self.silu(self.emb(timestep)))
- scale, shift = pi.chunk(emb, 2)
- x = self.norm(x) * (1 + scale) + shift
- return x
-
-
-class AdaLayerNormZero(nn.Module):
- """
- Norm layer adaptive layer norm zero (adaLN-Zero).
- """
-
- def __init__(self, embedding_dim, num_embeddings):
- super().__init__()
-
- self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
-
- self.silu = nn.SiLU()
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
-
- def forward(self, x, timestep, class_labels, hidden_dtype=None):
- emb = self.linear(
- self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))
- )
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
- 6, dim=1
- )
- x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
diff --git a/pi/models/unet/cross_attention.py b/pi/models/unet/cross_attention.py
deleted file mode 100644
index dc15f86..0000000
--- a/pi/models/unet/cross_attention.py
+++ /dev/null
@@ -1,679 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Callable, Optional, Union
-
-from ... import nn
-from ... import pi
-from ...nn import functional as F
-
-
-xformers = None
-
-
-class CrossAttention(nn.Module):
- r"""
- A cross attention layer.
-
- Parameters:
- query_dim (`int`): The number of channels in the query.
- cross_attention_dim (`int`, *optional*):
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
- heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
- dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- bias (`bool`, *optional*, defaults to False):
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
- """
-
- def __init__(
- self,
- query_dim: int,
- cross_attention_dim: Optional[int] = None,
- heads: int = 8,
- dim_head: int = 64,
- dropout: float = 0.0,
- bias=False,
- upcast_attention: bool = False,
- upcast_softmax: bool = False,
- added_kv_proj_dim: Optional[int] = None,
- norm_num_groups: Optional[int] = None,
- processor: Optional["AttnProcessor"] = None,
- ):
- super().__init__()
- inner_dim = dim_head * heads
- cross_attention_dim = (
- cross_attention_dim if cross_attention_dim is not None else query_dim
- )
- self.upcast_attention = upcast_attention
- self.upcast_softmax = upcast_softmax
-
- self.scale = dim_head ** -0.5
-
- self.heads = heads
- # for slice_size > 0 the attention score computation
- # is split across the batch axis to save memory
- # You can set slice_size with `set_attention_slice`
- self.sliceable_head_dim = heads
-
- self.added_kv_proj_dim = added_kv_proj_dim
-
- if norm_num_groups is not None:
- self.group_norm = nn.GroupNorm(
- num_channels=inner_dim,
- num_groups=norm_num_groups,
- eps=1e-5,
- affine=True,
- )
- else:
- self.group_norm = None
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
- self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
- self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-
- if self.added_kv_proj_dim is not None:
- self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
- self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
-
- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(inner_dim, query_dim))
- self.to_out.append(nn.Dropout(dropout))
-
- # set attention processor
- processor = processor if processor is not None else CrossAttnProcessor()
- self.set_processor(processor)
-
- def set_use_memory_efficient_attention_xformers(
- self,
- use_memory_efficient_attention_xformers: bool,
- attention_op: Optional[Callable] = None,
- ):
- if use_memory_efficient_attention_xformers:
- # if self.added_kv_proj_dim is not None:
- # # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
- # # which uses this type of cross attention ONLY because the attention mask of format
- # # [0, ..., -10.000, ..., 0, ...,] is not supported
- # raise NotImplementedError(
- # "Memory efficient attention with `xformers` is currently not supported when"
- # " `self.added_kv_proj_dim` is defined."
- # )
- # elif not is_xformers_available():
- # raise ModuleNotFoundError(
- # (
- # "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
- # " xformers"
- # ),
- # name="xformers",
- # )
- # elif not pi.cuda.is_available():
- # raise ValueError(
- # "pi.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
- # " only available for GPU "
- # )
- # else:
- # try:
- # # Make sure we can run the memory efficient attention
- # _ = xformers.ops.memory_efficient_attention(
- # pi.randn((1, 2, 40), device="cuda"),
- # pi.randn((1, 2, 40), device="cuda"),
- # pi.randn((1, 2, 40), device="cuda"),
- # )
- # except Exception as e:
- # raise e
- #
- # processor = XFormersCrossAttnProcessor(attention_op=attention_op)
- raise NotImplementedError
- else:
- processor = CrossAttnProcessor()
-
- self.set_processor(processor)
-
- def set_attention_slice(self, slice_size):
- if slice_size is not None and slice_size > self.sliceable_head_dim:
- raise ValueError(
- f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}."
- )
-
- if slice_size is not None and self.added_kv_proj_dim is not None:
- processor = SlicedAttnAddedKVProcessor(slice_size)
- elif slice_size is not None:
- processor = SlicedAttnProcessor(slice_size)
- elif self.added_kv_proj_dim is not None:
- processor = CrossAttnAddedKVProcessor()
- else:
- processor = CrossAttnProcessor()
-
- self.set_processor(processor)
-
- def set_processor(self, processor: "AttnProcessor"):
- self.processor = processor
-
- def forward(
- self,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- **cross_attention_kwargs,
- ):
- # The `CrossAttention` class can call different attention processors / attention functions
- # here we simply pass along all tensors to the selected processor class
- # For standard processors that are defined here, `**cross_attention_kwargs` is empty
- return self.processor(
- self,
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-
- def batch_to_head_dim(self, tensor):
- head_size = self.heads
- batch_size, seq_len, dim = tensor.shape
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
- tensor = tensor.permute(0, 2, 1, 3).reshape(
- batch_size // head_size, seq_len, dim * head_size
- )
- return tensor
-
- def head_to_batch_dim(self, tensor):
- head_size = self.heads
- batch_size, seq_len, dim = tensor.shape
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
- tensor = tensor.permute(0, 2, 1, 3).reshape(
- batch_size * head_size, seq_len, dim // head_size
- )
- return tensor
-
- def get_attention_scores(self, query, key, attention_mask=None):
- dtype = query.dtype
- if self.upcast_attention:
- query = query.float()
- key = key.float()
-
- attention_scores = pi.baddbmm(
- pi.empty(
- query.shape[0],
- query.shape[1],
- key.shape[1],
- dtype=query.dtype,
- device=query.device,
- ),
- query,
- key.transpose(-1, -2),
- beta=0,
- alpha=self.scale,
- )
-
- if attention_mask is not None:
- attention_scores = attention_scores + attention_mask
-
- if self.upcast_softmax:
- attention_scores = attention_scores.float()
-
- attention_probs = attention_scores.softmax(dim=-1)
- attention_probs = attention_probs.to(dtype)
-
- return attention_probs
-
- def prepare_attention_mask(self, attention_mask, target_length):
- head_size = self.heads
- if attention_mask is None:
- return attention_mask
-
- if attention_mask.shape[-1] != target_length:
- if attention_mask.device.type == "mps":
- # HACK: MPS: Does not support padding by greater than dimension of input tensor.
- # Instead, we can manually construct the padding tensor.
- padding_shape = (
- attention_mask.shape[0],
- attention_mask.shape[1],
- target_length,
- )
- padding = pi.zeros(padding_shape, device=attention_mask.device)
- attention_mask = pi.concat([attention_mask, padding], dim=2)
- else:
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
- return attention_mask
-
-
-class CrossAttnProcessor:
- def __call__(
- self,
- attn: CrossAttention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- ):
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- query = attn.to_q(hidden_states)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states = (
- encoder_hidden_states
- if encoder_hidden_states is not None
- else hidden_states
- )
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = pi.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
-class LoRALinearLayer(nn.Module):
- def __init__(self, in_features, out_features, rank=4):
- super().__init__()
-
- if rank > min(in_features, out_features):
- raise ValueError(
- f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
- )
-
- self.down = nn.Linear(in_features, rank, bias=False)
- self.up = nn.Linear(rank, out_features, bias=False)
- self.scale = 1.0
-
- nn.init.normal_(self.down.weight, std=1 / rank)
- nn.init.zeros_(self.up.weight)
-
- def forward(self, hidden_states):
- orig_dtype = hidden_states.dtype
- dtype = self.down.weight.dtype
-
- down_hidden_states = self.down(hidden_states.to(dtype))
- up_hidden_states = self.up(down_hidden_states)
-
- return up_hidden_states.to(orig_dtype)
-
-
-class LoRACrossAttnProcessor(nn.Module):
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
- super().__init__()
-
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
- self.to_k_lora = LoRALinearLayer(
- cross_attention_dim or hidden_size, hidden_size
- )
- self.to_v_lora = LoRALinearLayer(
- cross_attention_dim or hidden_size, hidden_size
- )
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
- def __call__(
- self,
- attn: CrossAttention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- scale=1.0,
- ):
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states = (
- encoder_hidden_states
- if encoder_hidden_states is not None
- else hidden_states
- )
-
- key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(
- encoder_hidden_states
- )
- value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(
- encoder_hidden_states
- )
-
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = pi.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(
- hidden_states
- )
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
-class CrossAttnAddedKVProcessor:
- def __call__(
- self,
- attn: CrossAttention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- ):
- residual = hidden_states
- hidden_states = hidden_states.view(
- hidden_states.shape[0], hidden_states.shape[1], -1
- ).transpose(1, 2)
- batch_size, sequence_length, _ = hidden_states.shape
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
-
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- query = attn.head_to_batch_dim(query)
-
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(
- encoder_hidden_states_key_proj
- )
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(
- encoder_hidden_states_value_proj
- )
-
- key = pi.concat([encoder_hidden_states_key_proj, key], dim=1)
- value = pi.concat([encoder_hidden_states_value_proj, value], dim=1)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = pi.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
-
- return hidden_states
-
-
-class XFormersCrossAttnProcessor:
- def __init__(self, attention_op: Optional[Callable] = None):
- self.attention_op = attention_op
-
- def __call__(
- self,
- attn: CrossAttention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- ):
- batch_size, sequence_length, _ = hidden_states.shape
-
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- query = attn.to_q(hidden_states)
-
- encoder_hidden_states = (
- encoder_hidden_states
- if encoder_hidden_states is not None
- else hidden_states
- )
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op
- )
- hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- return hidden_states
-
-
-class LoRAXFormersCrossAttnProcessor(nn.Module):
- def __init__(self, hidden_size, cross_attention_dim, rank=4):
- super().__init__()
-
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
- self.to_k_lora = LoRALinearLayer(
- cross_attention_dim or hidden_size, hidden_size
- )
- self.to_v_lora = LoRALinearLayer(
- cross_attention_dim or hidden_size, hidden_size
- )
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
- def __call__(
- self,
- attn: CrossAttention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- scale=1.0,
- ):
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
- query = attn.head_to_batch_dim(query).contiguous()
-
- encoder_hidden_states = (
- encoder_hidden_states
- if encoder_hidden_states is not None
- else hidden_states
- )
-
- key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(
- encoder_hidden_states
- )
- value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(
- encoder_hidden_states
- )
-
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(
- hidden_states
- )
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
-class SlicedAttnProcessor:
- def __init__(self, slice_size):
- self.slice_size = slice_size
-
- def __call__(
- self,
- attn: CrossAttention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- ):
- batch_size, sequence_length, _ = hidden_states.shape
-
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- query = attn.to_q(hidden_states)
- dim = query.shape[-1]
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states = (
- encoder_hidden_states
- if encoder_hidden_states is not None
- else hidden_states
- )
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- batch_size_attention = query.shape[0]
- hidden_states = pi.zeros(
- (batch_size_attention, sequence_length, dim // attn.heads),
- device=query.device,
- dtype=query.dtype,
- )
-
- for i in range(hidden_states.shape[0] // self.slice_size):
- start_idx = i * self.slice_size
- end_idx = (i + 1) * self.slice_size
-
- query_slice = query[start_idx:end_idx]
- key_slice = key[start_idx:end_idx]
- attn_mask_slice = (
- attention_mask[start_idx:end_idx]
- if attention_mask is not None
- else None
- )
-
- attn_slice = attn.get_attention_scores(
- query_slice, key_slice, attn_mask_slice
- )
-
- attn_slice = pi.bmm(attn_slice, value[start_idx:end_idx])
-
- hidden_states[start_idx:end_idx] = attn_slice
-
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
-class SlicedAttnAddedKVProcessor:
- def __init__(self, slice_size):
- self.slice_size = slice_size
-
- def __call__(
- self,
- attn: "CrossAttention",
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- ):
- residual = hidden_states
- hidden_states = hidden_states.view(
- hidden_states.shape[0], hidden_states.shape[1], -1
- ).transpose(1, 2)
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
-
- batch_size, sequence_length, _ = hidden_states.shape
-
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- dim = query.shape[-1]
- query = attn.head_to_batch_dim(query)
-
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(
- encoder_hidden_states_key_proj
- )
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(
- encoder_hidden_states_value_proj
- )
-
- key = pi.concat([encoder_hidden_states_key_proj, key], dim=1)
- value = pi.concat([encoder_hidden_states_value_proj, value], dim=1)
-
- batch_size_attention = query.shape[0]
- hidden_states = pi.zeros(
- (batch_size_attention, sequence_length, dim // attn.heads),
- device=query.device,
- dtype=query.dtype,
- )
-
- for i in range(hidden_states.shape[0] // self.slice_size):
- start_idx = i * self.slice_size
- end_idx = (i + 1) * self.slice_size
-
- query_slice = query[start_idx:end_idx]
- key_slice = key[start_idx:end_idx]
- attn_mask_slice = (
- attention_mask[start_idx:end_idx]
- if attention_mask is not None
- else None
- )
-
- attn_slice = attn.get_attention_scores(
- query_slice, key_slice, attn_mask_slice
- )
-
- attn_slice = pi.bmm(attn_slice, value[start_idx:end_idx])
-
- hidden_states[start_idx:end_idx] = attn_slice
-
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
-
- return hidden_states
-
-
-AttnProcessor = Union[
- CrossAttnProcessor,
- XFormersCrossAttnProcessor,
- SlicedAttnProcessor,
- CrossAttnAddedKVProcessor,
- SlicedAttnAddedKVProcessor,
-]
diff --git a/pi/models/unet/embeddings.py b/pi/models/unet/embeddings.py
deleted file mode 100644
index de0cbe8..0000000
--- a/pi/models/unet/embeddings.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-import numpy as np
-from ... import nn
-from ... import pi
-from ...nn import functional as F
-
-
-def get_timestep_embedding(
- timesteps: pi.Tensor,
- embedding_dim: int,
- flip_sin_to_cos: bool = False,
- downscale_freq_shift: float = 1,
- scale: float = 1,
- max_period: int = 10000,
-):
- """
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
-
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
- """
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
-
- half_dim = embedding_dim // 2
- exponent = -math.log(max_period) * pi.arange(
- start=0, end=half_dim, dtype=pi.float32, device=timesteps.device
- )
- exponent = exponent / (half_dim - downscale_freq_shift)
-
- emb = pi.exp(exponent)
- emb = timesteps[:, None].float() * emb[None, :]
-
- # scale embeddings
- emb = scale * emb
-
- # concat sine and cosine embeddings
- emb = pi.cat([pi.sin(emb), pi.cos(emb)], dim=-1)
-
- # flip sine and cosine embeddings
- if flip_sin_to_cos:
- emb = pi.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
-
- # zero pad
- if embedding_dim % 2 == 1:
- emb = pi.nn.functional.pad(emb, (0, 1, 0, 0))
- return emb
-
-
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
- """
- grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token and extra_tokens > 0:
- pos_embed = np.concatenate(
- [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
- )
- return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- if embed_dim % 2 != 0:
- raise ValueError("embed_dim must be divisible by 2")
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- """
- embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
- """
- if embed_dim % 2 != 0:
- raise ValueError("embed_dim must be divisible by 2")
-
- omega = np.arange(embed_dim // 2, dtype=np.float64)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000 ** omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- return emb
-
-
-class PatchEmbed(nn.Module):
- """2D Image to Patch Embedding"""
-
- def __init__(
- self,
- height=224,
- width=224,
- patch_size=16,
- in_channels=3,
- embed_dim=768,
- layer_norm=False,
- flatten=True,
- bias=True,
- ):
- super().__init__()
-
- num_patches = (height // patch_size) * (width // patch_size)
- self.flatten = flatten
- self.layer_norm = layer_norm
-
- self.proj = nn.Conv2d(
- in_channels,
- embed_dim,
- kernel_size=(patch_size, patch_size),
- stride=patch_size,
- bias=bias,
- )
- if layer_norm:
- self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
- else:
- self.norm = None
-
- pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** 0.5))
- self.register_buffer(
- "pos_embed", pi.from_numpy(pos_embed).float().unsqueeze(0), persistent=False
- )
-
- def forward(self, latent):
- latent = self.proj(latent)
- if self.flatten:
- latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
- if self.layer_norm:
- latent = self.norm(latent)
- return latent + self.pos_embed
-
-
-class TimestepEmbedding(nn.Module):
- def __init__(
- self,
- in_channels: int,
- time_embed_dim: int,
- act_fn: str = "silu",
- out_dim: int = None,
- ):
- super().__init__()
-
- self.linear_1 = nn.Linear(in_channels, time_embed_dim)
- self.act = None
- if act_fn == "silu":
- self.act = nn.SiLU()
- elif act_fn == "mish":
- self.act = nn.Mish()
-
- if out_dim is not None:
- time_embed_dim_out = out_dim
- else:
- time_embed_dim_out = time_embed_dim
- self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
-
- def forward(self, sample):
- sample = self.linear_1(sample)
-
- if self.act is not None:
- sample = self.act(sample)
-
- sample = self.linear_2(sample)
- return sample
-
-
-class Timesteps(nn.Module):
- def __init__(
- self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float
- ):
- super().__init__()
- self.num_channels = num_channels
- self.flip_sin_to_cos = flip_sin_to_cos
- self.downscale_freq_shift = downscale_freq_shift
-
- def forward(self, timesteps):
- t_emb = get_timestep_embedding(
- timesteps,
- self.num_channels,
- flip_sin_to_cos=self.flip_sin_to_cos,
- downscale_freq_shift=self.downscale_freq_shift,
- )
- return t_emb
-
-
-class GaussianFourierProjection(nn.Module):
- """Gaussian Fourier embeddings for noise levels."""
-
- def __init__(
- self,
- embedding_size: int = 256,
- scale: float = 1.0,
- set_W_to_weight=True,
- log=True,
- flip_sin_to_cos=False,
- ):
- super().__init__()
- self.weight = nn.Parameter(
- pi.randn(embedding_size) * scale, requires_grad=False
- )
- self.log = log
- self.flip_sin_to_cos = flip_sin_to_cos
-
- if set_W_to_weight:
- # to delete later
- self.W = nn.Parameter(pi.randn(embedding_size) * scale, requires_grad=False)
-
- self.weight = self.W
-
- def forward(self, x):
- if self.log:
- x = pi.log(x)
-
- x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
-
- if self.flip_sin_to_cos:
- out = pi.cat([pi.cos(x_proj), pi.sin(x_proj)], dim=-1)
- else:
- out = pi.cat([pi.sin(x_proj), pi.cos(x_proj)], dim=-1)
- return out
-
-
-class ImagePositionalEmbeddings(nn.Module):
- """
- Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
- height and width of the latent space.
-
- For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
-
- For VQ-diffusion:
-
- Output vector embeddings are used as input for the transformer.
-
- Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
-
- Args:
- num_embed (`int`):
- Number of embeddings for the latent pixels embeddings.
- height (`int`):
- Height of the latent image i.e. the number of height embeddings.
- width (`int`):
- Width of the latent image i.e. the number of width embeddings.
- embed_dim (`int`):
- Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
- """
-
- def __init__(
- self,
- num_embed: int,
- height: int,
- width: int,
- embed_dim: int,
- ):
- super().__init__()
-
- self.height = height
- self.width = width
- self.num_embed = num_embed
- self.embed_dim = embed_dim
-
- self.emb = nn.Embedding(self.num_embed, embed_dim)
- self.height_emb = nn.Embedding(self.height, embed_dim)
- self.width_emb = nn.Embedding(self.width, embed_dim)
-
- def forward(self, index):
- emb = self.emb(index)
-
- height_emb = self.height_emb(
- pi.arange(self.height, device=index.device).view(1, self.height)
- )
-
- # 1 x H x D -> 1 x H x 1 x D
- height_emb = height_emb.unsqueeze(2)
-
- width_emb = self.width_emb(
- pi.arange(self.width, device=index.device).view(1, self.width)
- )
-
- # 1 x W x D -> 1 x 1 x W x D
- width_emb = width_emb.unsqueeze(1)
-
- pos_emb = height_emb + width_emb
-
-        # 1 x H x W x D -> 1 x L x D
- pos_emb = pos_emb.view(1, self.height * self.width, -1)
-
- emb = emb + pos_emb[:, : emb.shape[1], :]
-
- return emb
-
-
-class LabelEmbedding(nn.Module):
- """
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
-
- Args:
- num_classes (`int`): The number of classes.
- hidden_size (`int`): The size of the vector embeddings.
- dropout_prob (`float`): The probability of dropping a label.
- """
-
- def __init__(self, num_classes, hidden_size, dropout_prob):
- super().__init__()
- use_cfg_embedding = dropout_prob > 0
- self.embedding_table = nn.Embedding(
- num_classes + use_cfg_embedding, hidden_size
- )
- self.num_classes = num_classes
- self.dropout_prob = dropout_prob
-
- def token_drop(self, labels, force_drop_ids=None):
- """
- Drops labels to enable classifier-free guidance.
- """
- if force_drop_ids is None:
- drop_ids = (
- pi.rand(labels.shape[0], device=labels.device) < self.dropout_prob
- )
- else:
- drop_ids = pi.tensor(force_drop_ids == 1)
- labels = pi.where(drop_ids, self.num_classes, labels)
- return labels
-
- def forward(self, labels, force_drop_ids=None):
- use_dropout = self.dropout_prob > 0
- if (self.training and use_dropout) or (force_drop_ids is not None):
- labels = self.token_drop(labels, force_drop_ids)
- embeddings = self.embedding_table(labels)
- return embeddings
-
-
-class CombinedTimestepLabelEmbeddings(nn.Module):
- def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
- super().__init__()
-
- self.time_proj = Timesteps(
- num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1
- )
- self.timestep_embedder = TimestepEmbedding(
- in_channels=256, time_embed_dim=embedding_dim
- )
- self.class_embedder = LabelEmbedding(
- num_classes, embedding_dim, class_dropout_prob
- )
-
- def forward(self, timestep, class_labels, hidden_dtype=None):
- timesteps_proj = self.time_proj(timestep)
- timesteps_emb = self.timestep_embedder(
- timesteps_proj.to(dtype=hidden_dtype)
- ) # (N, D)
-
- class_labels = self.class_embedder(class_labels) # (N, D)
-
- conditioning = timesteps_emb + class_labels # (N, D)
-
- return conditioning
diff --git a/pi/models/unet/hand_imports.py b/pi/models/unet/hand_imports.py
new file mode 100644
index 0000000..652d090
--- /dev/null
+++ b/pi/models/unet/hand_imports.py
@@ -0,0 +1,21 @@
+from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, AttnDownBlock2D, SimpleCrossAttnDownBlock2D, \
+ SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, \
+ KCrossAttnDownBlock2D, ResnetUpsampleBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, \
+ AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
+from diffusers.models import DualTransformer2DModel
+from diffusers.models.attention import AdaGroupNorm
+from diffusers.models.cross_attention import CrossAttnAddedKVProcessor, AttnProcessor
+from diffusers.models.embeddings import GaussianFourierProjection, PatchEmbed, ImagePositionalEmbeddings
+from diffusers.models.resnet import FirUpsample2D, FirDownsample2D, downsample_2d, upsample_2d
+from diffusers.models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn
+from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.models.resnet import upfirdn2d_native
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention import AttentionBlock, AdaLayerNorm, GELU, ApproximateGELU
+from diffusers.models.cross_attention import LoRACrossAttnProcessor, CrossAttnProcessor, LoRAXFormersCrossAttnProcessor, \
+ XFormersCrossAttnProcessor, SlicedAttnAddedKVProcessor, SlicedAttnProcessor, AttnProcessor
+from diffusers.models.resnet import KDownsample2D, KUpsample2D
+from diffusers.models.unet_2d_blocks import KAttentionBlock
+from diffusers.models.cross_attention import LoRALinearLayer, AttnProcessor
+from diffusers.models.attention import AdaLayerNormZero
+from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
diff --git a/pi/models/unet/linearized.py b/pi/models/unet/linearized.py
new file mode 100644
index 0000000..d421b35
--- /dev/null
+++ b/pi/models/unet/linearized.py
@@ -0,0 +1,5014 @@
+from __future__ import annotations
+import functools
+import importlib
+import inspect
+import json
+import logging
+import math
+import os
+from collections import defaultdict
+from dataclasses import dataclass, fields
+from functools import partial
+from pathlib import PosixPath
+from typing import Optional, Callable, Tuple, Union, Dict, List, Any, OrderedDict
+import sys
+
+import numpy as np
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from pi.models.unet.prologue import CONFIG_NAME, LORA_WEIGHT_NAME
+
+logger = logging.getLogger(__name__)
+
+
+def register_to_config(init):
+ r"""
+    Decorator to apply to the `__init__` of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the `__init__` but that
+    shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default
+ for i, (name, p) in enumerate(signature.parameters.items())
+ if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
+ getattr(self, "register_to_config")(**new_kwargs)
+ init(self, *args, **init_kwargs)
+
+ return inner_init
+
+
+class BaseOutput(OrderedDict):
+ """
+    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+    Python dictionary.
+
+    You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+    first.
+    """
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+ first_field = getattr(self, class_fields[0].name)
+ other_fields_are_none = all(
+ getattr(self, field.name) is None for field in class_fields[1:]
+ )
+
+ if other_fields_are_none and isinstance(first_field, dict):
+ for key, value in first_field.items():
+ self[key] = value
+ else:
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
+ )
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
+ )
+
+ def pop(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
+ )
+
+ def update(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``update`` on a {self.__class__.__name__} instance."
+ )
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = {k: v for (k, v) in self.items()}
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
+ for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
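+# Minimal sketch (purely illustrative, not used elsewhere in this module) of how the
+# BaseOutput container behaves: it can be indexed by key or by position, but cannot be
+# unpacked directly -- use `.to_tuple()` for that. The tensor shape is arbitrary.
+def _base_output_sketch():
+    out = Transformer2DModelOutput(sample=torch.zeros(1, 4, 8, 8))
+    assert out["sample"] is out.sample  # dict-style and attribute access agree
+    assert out[0] is out.sample  # integer indexing goes through to_tuple()
+    (sample,) = out.to_tuple()  # explicit unpacking path
+    return sample
+
+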
+class AttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: CrossAttention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ ):
+ batch_size, sequence_length, inner_dim = hidden_states.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.cross_attention_norm:
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (
+ num_embeds_ada_norm is not None
+ ) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (
+ num_embeds_ada_norm is not None
+ ) and norm_type == "ada_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # 1. Self-Attn
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.attn2 = None
+
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+
+ if cross_attention_dim is not None:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ )
+ else:
+ self.norm2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ class_labels=None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = (
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ )
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states
+ if self.only_cross_attention
+ else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep)
+ if self.use_ada_layer_norm
+ else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = (
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ )
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
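+# Minimal sketch of driving a BasicTransformerBlock on its own. The sizes are
+# illustrative and chosen so dim == num_attention_heads * attention_head_dim, which keeps
+# the default attention processor's reshapes consistent; the GEGLU feed-forward and the
+# fallback CrossAttnProcessor are assumed to be defined elsewhere in this file.
+def _basic_transformer_block_sketch():
+    block = BasicTransformerBlock(
+        dim=64,
+        num_attention_heads=2,
+        attention_head_dim=32,
+        cross_attention_dim=32,
+    )
+    hidden_states = torch.randn(1, 16, 64)  # (batch, tokens, dim)
+    encoder_hidden_states = torch.randn(1, 4, 32)  # (batch, context tokens, cross_attention_dim)
+    return block(hidden_states, encoder_hidden_states=encoder_hidden_states)
+
+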
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by subclass).
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+ subclass).
+ """
+ config_name = None
+ ignore_for_config = []
+ has_compatibles = False
+
+ _deprecated_kwargs = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(
+ f"Make sure that {self.__class__} has defined a class name `config_name`"
+ )
+ # Special case for `kwargs` used in deprecation warning added to schedulers
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+ # or solve in a more general way.
+ kwargs.pop("kwargs", None)
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def save_config(
+ self,
+ save_directory: Union[str, os.PathLike],
+ push_to_hub: bool = False,
+ **kwargs,
+ ):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(
+ f"Provided path ({save_directory}) should be a directory, not a file"
+ )
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(
+ cls,
+ config: Union[FrozenDict, Dict[str, Any]] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Python class from a config dictionary
+
+ Parameters:
+ config (`Dict[str, Any]`):
+ A config dictionary from which the Python class will be instantiated. Make sure to only load
+ configuration files of compatible classes.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the Python class.
+ `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
+ overwrite same named arguments of `config`.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+ >>> # Download scheduler from huggingface.co and cache.
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+ ```
+ """
+ # <===== TO BE REMOVED WITH DEPRECATION
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+ if "pretrained_model_name_or_path" in kwargs:
+ config = kwargs.pop("pretrained_model_name_or_path")
+
+ if config is None:
+ raise ValueError(
+ "Please make sure to provide a config as the first positional argument."
+ )
+ # ======>
+
+ if not isinstance(config, dict):
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+ if "Scheduler" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+ " be removed in v1.0.0."
+ )
+ elif "Model" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+ " instead. This functionality will be removed in v1.0.0."
+ )
+ deprecate(
+ "config-passed-as-path",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ config, kwargs = cls.load_config(
+ pretrained_model_name_or_path=config,
+ return_unused_kwargs=True,
+ **kwargs,
+ )
+
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+ # Allow dtype to be specified on initialization
+ if "dtype" in unused_kwargs:
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+ # add possible deprecated kwargs
+ for deprecated_kwarg in cls._deprecated_kwargs:
+ if deprecated_kwarg in unused_kwargs:
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+ # Return model and optionally state and/or unused_kwargs
+ model = cls(**init_dict)
+
+ # make sure to also save config parameters that might be used for compatible classes
+ model.register_to_config(**hidden_dict)
+
+ # add hidden kwargs of compatible classes to unused_kwargs
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+ if return_unused_kwargs:
+ return (model, unused_kwargs)
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(cls, *args, **kwargs):
+ deprecation_message = (
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+ " removed in version v1.0.0"
+ )
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+ return cls.load_config(*args, **kwargs)
+
+ @classmethod
+ def load_config(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ return_unused_kwargs=False,
+ **kwargs,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+        Load a model or scheduler configuration dictionary from a local path or a repo on huggingface.co
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ _ = kwargs.pop("mirror", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "config"}
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, cls.config_name)
+ ):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(
+ pretrained_model_name_or_path, cls.config_name
+ )
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, cls.config_name
+ )
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(
+ f"It looks like the config file at '{config_file}' is not a valid JSON file."
+ )
+
+ if return_unused_kwargs:
+ return config_dict, kwargs
+
+ return config_dict
+
+ @staticmethod
+ def _get_init_keys(cls):
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ # 0. Copy origin config dict
+ original_dict = {k: v for k, v in config_dict.items()}
+
+ # 1. Retrieve expected config attributes from __init__ signature
+ expected_keys = cls._get_init_keys(cls)
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove flax internal keys
+ if hasattr(cls, "_flax_internal_args"):
+ for arg in cls._flax_internal_args:
+ expected_keys.remove(arg)
+
+ # 2. Remove attributes that cannot be expected from expected config attributes
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+
+ # load diffusers library to import compatible and original scheduler
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+ if cls.has_compatibles:
+ compatible_classes = [
+ c for c in cls._get_compatibles() if not isinstance(c, DummyObject)
+ ]
+ else:
+ compatible_classes = []
+
+ expected_keys_comp_cls = set()
+ for c in compatible_classes:
+ expected_keys_c = cls._get_init_keys(c)
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+ config_dict = {
+ k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls
+ }
+
+ # remove attributes from orig class that cannot be expected
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
+ orig_cls = getattr(diffusers_library, orig_cls_name)
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+ config_dict = {
+ k: v
+ for k, v in config_dict.items()
+ if k not in unexpected_keys_from_orig
+ }
+
+ # remove private attributes
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+ init_dict = {}
+ for key in expected_keys:
+ # if config param is passed to kwarg and is present in config dict
+ # it should overwrite existing config dict key
+ if key in kwargs and key in config_dict:
+ config_dict[key] = kwargs.pop(key)
+
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+ # 4. Give nice warning if unexpected values have been passed
+ if len(config_dict) > 0:
+ logger.warning(
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
+ "but are not expected and will be ignored. Please verify your "
+ f"{cls.config_name} configuration file."
+ )
+
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.info(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ # 6. Define unused keyword arguments
+ unused_kwargs = {**config_dict, **kwargs}
+
+ # 7. Define "hidden" config parameters that were saved for compatible classes
+ hidden_config_dict = {
+ k: v for k, v in original_dict.items() if k not in init_dict
+ }
+
+ return init_dict, unused_kwargs, hidden_config_dict
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """
+ Returns the config of the class as a frozen dictionary
+
+ Returns:
+ `Dict[str, Any]`: Config of the class.
+ """
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ config_dict["_class_name"] = self.__class__.__name__
+ config_dict["_diffusers_version"] = __version__
+
+ def to_json_saveable(value):
+ if isinstance(value, np.ndarray):
+ value = value.tolist()
+ elif isinstance(value, PosixPath):
+ value = str(value)
+ return value
+
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
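+# Minimal sketch of how `register_to_config` and `ConfigMixin` cooperate: the decorated
+# __init__ records every non-private argument (including defaults) in `self.config`.
+# The class below is purely illustrative and relies on the FrozenDict helper defined
+# elsewhere in this file.
+def _config_mixin_sketch():
+    class TinyConfigured(ConfigMixin):
+        config_name = "tiny_config.json"
+
+        @register_to_config
+        def __init__(self, hidden_size: int = 8, act_fn: str = "silu"):
+            self.hidden_size = hidden_size
+            self.act_fn = act_fn
+
+    tiny = TinyConfigured(hidden_size=16)
+    # Defaults are captured alongside explicitly passed values.
+    return tiny.config["hidden_size"], tiny.config["act_fn"]  # (16, "silu")
+
+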
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ processor: Optional["AttnProcessor"] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = (
+ cross_attention_dim if cross_attention_dim is not None else query_dim
+ )
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.cross_attention_norm = cross_attention_norm
+
+ self.scale = dim_head ** -0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(
+ num_channels=inner_dim,
+ num_groups=norm_num_groups,
+ eps=1e-5,
+ affine=True,
+ )
+ else:
+ self.group_norm = None
+
+ if cross_attention_norm:
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # set attention processor
+        # By default we use AttnProcessor2_0 when torch 2.x is available, which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory-efficient attention.
+ if processor is None:
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention")
+ else CrossAttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self,
+ use_memory_efficient_attention_xformers: bool,
+ attention_op: Optional[Callable] = None,
+ ):
+ is_lora = hasattr(self, "processor") and isinstance(
+ self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if self.added_kv_proj_dim is not None:
+ # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+ # which uses this type of cross attention ONLY because the attention mask of format
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
+ raise NotImplementedError(
+ "Memory efficient attention with `xformers` is currently not supported when"
+ " `self.added_kv_proj_dim` is defined."
+ )
+ elif not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_lora:
+ processor = LoRAXFormersCrossAttnProcessor(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ else:
+ processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+ else:
+ if is_lora:
+ processor = LoRACrossAttnProcessor(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ else:
+ processor = CrossAttnProcessor()
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}."
+ )
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = CrossAttnAddedKVProcessor()
+ else:
+ processor = CrossAttnProcessor()
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor"):
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(
+ f"You are removing possibly trained weights of {self.processor} with {processor}"
+ )
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ **cross_attention_kwargs,
+ ):
+ # The `CrossAttention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor):
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
+ batch_size // head_size, seq_len, dim * head_size
+ )
+ return tensor
+
+ def head_to_batch_dim(self, tensor):
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
+ batch_size * head_size, seq_len, dim // head_size
+ )
+ return tensor
+
+ def get_attention_scores(self, query, key, attention_mask=None):
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0],
+ query.shape[1],
+ key.shape[1],
+ dtype=query.dtype,
+ device=query.device,
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+ if batch_size is None:
+ deprecate(
+ "batch_size=None",
+ "0.0.15",
+ message=(
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
+ " `prepare_attention_mask` when preparing the attention_mask."
+ ),
+ )
+ batch_size = 1
+
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ if attention_mask.shape[-1] != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (
+ attention_mask.shape[0],
+ attention_mask.shape[1],
+ target_length,
+ )
+ padding = torch.zeros(
+ padding_shape,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ return attention_mask
+
+
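+# Minimal sketch of CrossAttention in isolation: with encoder_hidden_states it performs
+# cross-attention, without it self-attention. Sizes are illustrative and chosen so that
+# query_dim == heads * dim_head; on torch < 2.0 the default processor falls back to the
+# CrossAttnProcessor defined elsewhere in this file.
+def _cross_attention_sketch():
+    attn = CrossAttention(query_dim=64, cross_attention_dim=32, heads=2, dim_head=32)
+    hidden_states = torch.randn(1, 16, 64)  # (batch, tokens, query_dim)
+    encoder_hidden_states = torch.randn(1, 4, 32)  # (batch, context tokens, cross_attention_dim)
+    out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
+    # Processors are swappable at runtime, e.g. explicitly selecting the torch 2.x path:
+    if hasattr(F, "scaled_dot_product_attention"):
+        attn.set_processor(AttnProcessor2_0())
+    return out
+
+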
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ ):
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ cross_attention_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
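+# Minimal sketch of a CrossAttnDownBlock2D forward pass. It assumes the ResnetBlock2D and
+# Transformer2DModel classes defined elsewhere in this file; the channel, timestep and
+# context sizes are illustrative only.
+def _cross_attn_down_block_sketch():
+    block = CrossAttnDownBlock2D(
+        in_channels=32,
+        out_channels=32,
+        temb_channels=128,
+        attn_num_head_channels=8,
+        cross_attention_dim=32,
+    )
+    sample = torch.randn(2, 32, 16, 16)  # (batch, channels, height, width)
+    temb = torch.randn(2, 128)  # time embedding, (batch, temb_channels)
+    encoder_hidden_states = torch.randn(2, 4, 32)
+    hidden_states, skip_connections = block(
+        sample, temb=temb, encoder_hidden_states=encoder_hidden_states
+    )
+    return hidden_states, skip_connections
+
+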
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ cross_attention_kwargs=None,
+ upsample_size=None,
+ attention_mask=None,
+ ):
+ # TODO(Patrick, William) - attention mask is not used
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ cross_attention_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
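+# Minimal sketch of a CrossAttnUpBlock2D forward pass, showing how skip connections from
+# the down path are concatenated channel-wise before each resnet. It assumes the
+# ResnetBlock2D, Transformer2DModel and Upsample2D classes defined elsewhere in this
+# file; all sizes are illustrative.
+def _cross_attn_up_block_sketch():
+    block = CrossAttnUpBlock2D(
+        in_channels=32,
+        out_channels=32,
+        prev_output_channel=32,
+        temb_channels=128,
+        num_layers=1,
+        attn_num_head_channels=8,
+        cross_attention_dim=32,
+    )
+    hidden_states = torch.randn(2, 32, 8, 8)
+    skip = (torch.randn(2, 32, 8, 8),)  # res_hidden_states_tuple from the down path
+    temb = torch.randn(2, 128)
+    encoder_hidden_states = torch.randn(2, 4, 32)
+    # Output is upsampled to (2, 32, 16, 16) because add_upsample defaults to True.
+    return block(
+        hidden_states,
+        skip,
+        temb=temb,
+        encoder_hidden_states=encoder_hidden_states,
+    )
+
+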
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class Downsample2D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+        out_channels: the number of output channels; defaults to `channels`.
+        padding: the padding applied by the downsampling convolution.
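+
+    Example (a minimal sketch; the output spatial size assumes stride-2 downsampling of a 64x64 input):
+
+    ```py
+    >>> import torch
+    >>> downsample = Downsample2D(channels=32, use_conv=True)
+    >>> downsample(torch.randn(1, 32, 64, 64)).shape
+    torch.Size([1, 32, 32, 32])
+    ```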
+ """
+
+ def __init__(
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ # assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+
+ # assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class DummyObject(type):
+ """
+    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+    `requires_backends` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
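+
+    Example (a minimal sketch; with `dim=64` and `mult=2` the hidden dimension is 128):
+
+    ```py
+    >>> import torch
+    >>> ff = FeedForward(dim=64, mult=2, activation_fn="geglu")
+    >>> ff(torch.randn(1, 16, 64)).shape
+    torch.Size([1, 16, 64])
+    ```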
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out))
+        # FF layers as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+class FrozenDict(OrderedDict):
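+    r"""
+    An immutable `OrderedDict`: every key is also exposed as an attribute, and any attempt to mutate the
+    mapping after construction (`__setitem__`, `__delitem__`, `pop`, `update`, ...) raises an exception.
+    """
+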
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
+ )
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
+ )
+
+ def pop(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
+ )
+
+ def update(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``update`` on a {self.__class__.__name__} instance."
+ )
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(
+ f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
+ )
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(
+ f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
+ )
+ super().__setitem__(name, value)
+
+
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
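+
+    Example (a minimal sketch; the projection internally doubles `dim_out` before gating):
+
+    ```py
+    >>> import torch
+    >>> geglu = GEGLU(dim_in=16, dim_out=32)
+    >>> geglu(torch.randn(2, 4, 16)).shape
+    torch.Size([2, 4, 32])
+    ```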
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def gelu(self, gate):
+ # if gate.device.type != "mps":
+ # return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~models.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+ _supports_gradient_checkpointing = False
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def is_gradient_checkpointing(self) -> bool:
+ """
+ Whether gradient checkpointing is activated for this model or not.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ return any(
+ hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing
+ for m in self.modules()
+ )
+
+ def enable_gradient_checkpointing(self):
+ """
+ Activates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if not self._supports_gradient_checkpointing:
+ raise ValueError(
+ f"{self.__class__.__name__} does not support gradient checkpointing."
+ )
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+ def disable_gradient_checkpointing(self):
+ """
+ Deactivates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if self._supports_gradient_checkpointing:
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ # Recursively walk through all the children.
+        # Any child that exposes the `set_use_memory_efficient_attention_xformers` method
+        # gets the message.
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ for module in self.children():
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_mem_eff(module)
+
+ def enable_xformers_memory_efficient_attention(
+ self, attention_op: Optional[Callable] = None
+ ):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+
+ Parameters:
+ attention_op (`Callable`, *optional*):
+ Override the default `None` operator for use as `op` argument to the
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+ function of xFormers.
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import UNet2DConditionModel
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ >>> model = UNet2DConditionModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+ ... )
+ >>> model = model.to("cuda")
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ ```
+ """
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~models.ModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
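+
+        Example (a sketch; the checkpoint id mirrors the one used elsewhere in this file, and the output
+        directory is illustrative):
+
+        ```py
+        >>> from diffusers import UNet2DConditionModel
+        >>> model = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+        >>> model.save_pretrained("./sd2-unet")
+        ```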
+ """
+ if safe_serialization and not is_safetensors_available():
+ raise ImportError(
+ "`safe_serialization` requires the `safetensors library: `pip install safetensors`."
+ )
+
+ if os.path.isfile(save_directory):
+ logger.error(
+ f"Provided path ({save_directory}) should be a directory, not a file"
+ )
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+ weights_name = _add_variant(weights_name, variant)
+
+ # Save the model
+ if safe_serialization:
+ safetensors.torch.save_file(
+ state_dict,
+ os.path.join(save_directory, weights_name),
+ metadata={"format": "pt"},
+ )
+ else:
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
+
+ logger.info(
+ f"Model weights saved in {os.path.join(save_directory, weights_name)}"
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs
+ ):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+ variant (`str`, *optional*):
+                If specified, load weights from a `variant` filename, *e.g.* `pytorch_model.<variant>.bin`.
+                `variant` is ignored when using `from_flax`.
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
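+        Example (a minimal sketch; requires network access or a locally cached copy of the checkpoint):
+
+        ```py
+        >>> from diffusers import UNet2DConditionModel
+        >>> unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+        ```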
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ from_flax = kwargs.pop("from_flax", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+ variant = kwargs.pop("variant", None)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if device_map is not None and not is_accelerate_available():
+ raise NotImplementedError(
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
+ )
+
+ # Check if we can handle device_map and dispatching the weights
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ if low_cpu_mem_usage is False and device_map is not None:
+ raise ValueError(
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+ )
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+        # Load model
+
+ model_file = None
+ if from_flax:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=FLAX_WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ # Convert the weights
+ from .modeling_pytorch_flax_utils import (
+ load_flax_checkpoint_in_pytorch_model,
+ )
+
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
+ else:
+ if is_safetensors_available():
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ except: # noqa: E722
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+
+ if low_cpu_mem_usage:
+ # Instantiate model with empty weights
+ with accelerate.init_empty_weights():
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
+ if device_map is None:
+ param_device = "cpu"
+ state_dict = load_state_dict(model_file, variant=variant)
+ # move the params from meta device to cpu
+ missing_keys = set(model.state_dict().keys()) - set(
+ state_dict.keys()
+ )
+ if len(missing_keys) > 0:
+ raise ValueError(
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
+ " those weights or else make sure your checkpoint file is correct."
+ )
+
+ for param_name, param in state_dict.items():
+ accepts_dtype = "dtype" in set(
+ inspect.signature(
+ set_module_tensor_to_device
+ ).parameters.keys()
+ )
+ if accepts_dtype:
+ set_module_tensor_to_device(
+ model,
+ param_name,
+ param_device,
+ value=param,
+ dtype=torch_dtype,
+ )
+ else:
+ set_module_tensor_to_device(
+ model, param_name, param_device, value=param
+ )
+ else: # else let accelerate handle loading and dispatching.
+ # Load weights and dispatch according to the device_map
+                        # by default the device_map is None and the weights are loaded on the CPU
+ accelerate.load_checkpoint_and_dispatch(
+ model, model_file, device_map, dtype=torch_dtype
+ )
+
+ loading_info = {
+ "missing_keys": [],
+ "unexpected_keys": [],
+ "mismatched_keys": [],
+ "error_msgs": [],
+ }
+ else:
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ state_dict = load_state_dict(model_file, variant=variant)
+
+ (
+ model,
+ missing_keys,
+ unexpected_keys,
+ mismatched_keys,
+ error_msgs,
+ ) = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = [k for k in state_dict.keys()]
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape
+ != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (
+ checkpoint_key,
+ state_dict[checkpoint_key].shape,
+ model_state_dict[model_key].shape,
+ )
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ raise RuntimeError(
+ f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
+ )
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(
+ f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
+ )
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(
+ self, only_trainable: bool = False, exclude_embeddings: bool = False
+ ) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
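+
+        Example (a sketch; the exact count depends on the loaded configuration):
+
+        ```py
+        >>> from diffusers import UNet2DConditionModel
+        >>> model = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+        >>> num_params = model.num_parameters(only_trainable=True)
+        ```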
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter
+ for name, parameter in self.named_parameters()
+ if name not in embedding_param_names
+ ]
+ return sum(
+ p.numel()
+ for p in non_embedding_parameters
+ if p.requires_grad or not only_trainable
+ )
+ else:
+ return sum(
+ p.numel()
+ for p in self.parameters()
+ if p.requires_grad or not only_trainable
+ )
+
+
+class OptionalDependencyNotAvailable(BaseException):
+ """An error indicating that an optional dependency of Diffusers was not found in the environment."""
+
+
+class ResnetBlock2D(nn.Module):
+ r"""
+ A Resnet block.
+
+ Parameters:
+ in_channels (`int`): The number of channels in the input.
+ out_channels (`int`, *optional*, default to be `None`):
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
+ groups_out (`int`, *optional*, default to None):
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
+ "ada_group" for a stronger conditioning with scale and shift.
+        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
+ use_in_shortcut (`bool`, *optional*, default to `True`):
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
+ `conv_shortcut` output.
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
+ If None, same as `out_channels`.
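+
+    Example (a minimal sketch; the channel counts are chosen so that the 1x1 shortcut convolution is exercised):
+
+    ```py
+    >>> import torch
+    >>> block = ResnetBlock2D(in_channels=32, out_channels=64, temb_channels=128)
+    >>> block(torch.randn(1, 32, 16, 16), torch.randn(1, 128)).shape
+    torch.Size([1, 64, 16, 16])
+    ```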
+ """
+
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default", # default, scale_shift, ada_group
+ kernel=None,
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ up=False,
+ down=False,
+ conv_shortcut_bias: bool = True,
+ conv_2d_out_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+ self.time_embedding_norm = time_embedding_norm
+
+ if groups_out is None:
+ groups_out = groups
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+ else:
+ self.norm1 = torch.nn.GroupNorm(
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
+ )
+
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ elif self.time_embedding_norm == "scale_shift":
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+ elif self.time_embedding_norm == "ada_group":
+ self.time_emb_proj = None
+ else:
+ raise ValueError(
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
+ )
+ else:
+ self.time_emb_proj = None
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+ else:
+ self.norm2 = torch.nn.GroupNorm(
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
+ )
+
+ self.dropout = torch.nn.Dropout(dropout)
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ elif non_linearity == "gelu":
+ self.nonlinearity = nn.GELU()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(
+ in_channels, use_conv=False, padding=1, name="op"
+ )
+
+ self.use_in_shortcut = (
+ self.in_channels != conv_2d_out_channels
+ if use_in_shortcut is None
+ else use_in_shortcut
+ )
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels,
+ conv_2d_out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=conv_shortcut_bias,
+ )
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ if self.time_embedding_norm == "ada_group":
+ hidden_states = self.norm1(hidden_states, temb)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ input_tensor = input_tensor.contiguous()
+ hidden_states = hidden_states.contiguous()
+ input_tensor = self.upsample(input_tensor)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ input_tensor = self.downsample(input_tensor)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if self.time_emb_proj is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ if self.time_embedding_norm == "ada_group":
+ hidden_states = self.norm2(hidden_states, temb)
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class TimestepEmbedding(nn.Module):
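+    r"""
+    A small MLP (`Linear` -> activation -> `Linear`) that projects a timestep embedding of size `in_channels`
+    up to `time_embed_dim`, with an optional conditioning projection and an optional post-activation.
+    """
+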
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ elif act_fn == "mish":
+ self.act = nn.Mish()
+ elif act_fn == "gelu":
+ self.act = nn.GELU()
+ else:
+ raise ValueError(
+ f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'"
+ )
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+ if post_act_fn is None:
+ self.post_act = None
+ elif post_act_fn == "silu":
+ self.post_act = nn.SiLU()
+ elif post_act_fn == "mish":
+ self.post_act = nn.Mish()
+ elif post_act_fn == "gelu":
+ self.post_act = nn.GELU()
+ else:
+ raise ValueError(
+ f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'"
+ )
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
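+    r"""
+    Wraps `get_timestep_embedding` to turn a 1-D tensor of timesteps into sinusoidal embeddings with
+    `num_channels` dimensions.
+    """
+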
+ def __init__(
+ self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
+ embeddings) inputs.
+
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
+ transformer action. Finally, reshape to image.
+
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
+ classes of unnoised image.
+
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to, but not more than, `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
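+
+    Example (a minimal sketch of the continuous-input path; the small sizes are illustrative):
+
+    ```py
+    >>> import torch
+    >>> model = Transformer2DModel(
+    ...     num_attention_heads=2, attention_head_dim=8, in_channels=16, norm_num_groups=8, cross_attention_dim=16
+    ... )
+    >>> sample = torch.randn(1, 16, 8, 8)
+    >>> encoder_hidden_states = torch.randn(1, 4, 16)
+    >>> model(sample, encoder_hidden_states=encoder_hidden_states).sample.shape
+    torch.Size([1, 16, 8, 8])
+    ```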
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate(
+ "norm_type!=num_embeds_ada_norm",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif (
+ not self.is_input_continuous
+ and not self.is_input_vectorized
+ and not self.is_input_patches
+ ):
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(
+ num_groups=norm_num_groups,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True,
+ )
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ elif self.is_input_vectorized:
+ assert (
+ sample_size is not None
+ ), "Transformer2DModel over discrete input must provide sample_size"
+ assert (
+ num_vector_embeds is not None
+ ), "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds,
+ embed_dim=inner_dim,
+ height=self.height,
+ width=self.width,
+ )
+ elif self.is_input_patches:
+ assert (
+ sample_size is not None
+ ), "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+            # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+ else:
+ self.proj_out = nn.Conv2d(
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches:
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(
+ inner_dim, patch_size * patch_size * self.out_channels
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ class_labels=None,
+ cross_attention_kwargs=None,
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+            hidden_states (When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+                hidden_states.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence length, embed dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
+ conditioning.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # 1. Input
+ if self.is_input_continuous:
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+ batch, height * width, inner_dim
+ )
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+ batch, height * width, inner_dim
+ )
+ hidden_states = self.proj_in(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+ elif self.is_input_patches:
+ hidden_states = self.pos_embed(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, width, inner_dim)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, width, inner_dim)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+ elif self.is_input_patches:
+ # TODO: cleanup!
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = (
+ self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ )
+ hidden_states = self.proj_out_2(hidden_states)
+
+ # unpatchify
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(
+ -1,
+ height,
+ width,
+ self.patch_size,
+ self.patch_size,
+ self.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ -1,
+ self.out_channels,
+ height * self.patch_size,
+ width * self.patch_size,
+ )
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+
+class UNet2DConditionLoadersMixin:
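+    r"""
+    Mixin that adds loading and saving of attention-processor weights (currently LoRA layers) to a UNet model.
+    """
+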
+ def load_attn_procs(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
+ defined in
+ [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
+ and be a `torch.nn.Module` class.
+
+        This function is experimental and might change in the future.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
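+        Example (a sketch; the weights directory is hypothetical and must contain LoRA attention weights saved with
+        [`~loaders.UNet2DConditionLoadersMixin.save_attn_procs`]):
+
+        ```py
+        >>> from diffusers import UNet2DConditionModel
+        >>> unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+        >>> unet.load_attn_procs("./lora-attn-procs")
+        ```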
+ """
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ # fill attn processors
+ attn_processors = {}
+
+ is_lora = all("lora" in k for k in state_dict.keys())
+
+ if is_lora:
+ lora_grouped_dict = defaultdict(dict)
+ for key, value in state_dict.items():
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(
+ key.split(".")[-3:]
+ )
+ lora_grouped_dict[attn_processor_key][sub_key] = value
+
+ for key, value_dict in lora_grouped_dict.items():
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
+ attn_processors[key] = LoRACrossAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=rank,
+ )
+ attn_processors[key].load_state_dict(value_dict)
+
+ else:
+ raise ValueError(
+ f"{model_file} does not seem to be in the correct format expected by LoRA training."
+ )
+
+ # set correct dtype & device
+ attn_processors = {
+ k: v.to(device=self.device, dtype=self.dtype)
+ for k, v in attn_processors.items()
+ }
+
+ # set layers
+ self.set_attn_processor(attn_processors)
+
+ def save_attn_procs(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ weights_name: str = LORA_WEIGHT_NAME,
+ save_function: Callable = None,
+ ):
+ r"""
+ Save an attention processor to a directory so that it can be re-loaded using the
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training
+ (e.g., on TPUs), where this function needs to be called on all processes. In that case, set
+ `is_main_process=True` only on the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training (e.g., on
+ TPUs), when `torch.save` needs to be replaced with another method. Can be configured with the
+ environment variable `DIFFUSERS_SAVE_MODE`.
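+
+ Example (illustrative sketch, assuming `unet` already has LoRA attention processors set):
+
+ >>> unet.save_attn_procs("./lora-attn-procs") # saves under `weights_name` (defaults to LORA_WEIGHT_NAME)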
+ """
+ if os.path.isfile(save_directory):
+ logger.error(
+ f"Provided path ({save_directory}) should be a directory, not a file"
+ )
+ return
+
+ if save_function is None:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = AttnProcsLayers(self.attn_processors)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ weights_no_suffix = weights_name.replace(".bin", "")
+ if (
+ filename.startswith(weights_no_suffix)
+ and os.path.isfile(full_filename)
+ and is_main_process
+ ):
+ os.remove(full_filename)
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, weights_name))
+
+ logger.info(
+ f"Model weights saved in {os.path.join(save_directory, weights_name)}"
+ )
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
+ and returns a sample-shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the models (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
+ mid block layer if `None`.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, it will skip the normalization and activation layers in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
+ summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
+ num_class_embeds (`int`, *optional*, defaults to None):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
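+
+ Example (a minimal sketch, not a recommended configuration; the small values are only illustrative):
+
+ >>> unet = UNet2DConditionModel(
+ ... sample_size=32, in_channels=4, out_channels=4, block_out_channels=(32, 64),
+ ... down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ ... up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
+ ... cross_attention_dim=32, attention_head_dim=8, layers_per_block=2,
+ ... )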
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ time_embedding_type: str = "positional",
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(
+ only_cross_attention
+ ) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
+ down_block_types
+ ):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=conv_in_kernel,
+ padding=conv_in_padding,
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
+ )
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2,
+ set_W_to_weight=False,
+ log=False,
+ flip_sin_to_cos=flip_sin_to_cos,
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(
+ block_out_channels[0], flip_sin_to_cos, freq_shift
+ )
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(
+ projection_class_embeddings_input_dim, time_embed_dim
+ )
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[
+ min(i + 1, len(block_out_channels) - 1)
+ ]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0],
+ num_groups=norm_num_groups,
+ eps=norm_eps,
+ )
+ self.conv_act = nn.SiLU()
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0],
+ out_channels,
+ kernel_size=conv_out_kernel,
+ padding=conv_out_padding,
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttnProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]
+ ):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(
+ self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]
+ ):
+ r"""
+ Parameters:
+ processor (`dict` of `AttnProcessor` or `AttnProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ of **all** `CrossAttention` layers.
+ If `processor` is a dict, the keys need to define the paths to the corresponding cross attention
+ processors. This is strongly recommended when setting trainable attention processors.
+
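+ Example (illustrative sketch; `AttnProcessor` here stands for any processor class exposed alongside this model):
+
+ >>> unet.set_attn_processor(AttnProcessor()) # set the same processor on all `CrossAttention` layers
+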
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
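+
+ Example (illustrative sketch):
+
+ >>> unet.set_attention_slice("auto") # halve the attention head size per slice
+ >>> unet.set_attention_slice("max") # smallest possible slices, most memory savings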
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = (
+ num_slicable_layers * [slice_size]
+ if not isinstance(slice_size, list)
+ else slice_size
+ )
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(
+ module: torch.nn.Module, slice_size: List[int]
+ ):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(
+ module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
+ ):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
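+
+ Example (illustrative shapes only; `cross_attention_dim` and spatial sizes depend on the config):
+
+ >>> sample = torch.randn(2, 4, 32, 32)
+ >>> timestep = torch.tensor([10])
+ >>> encoder_hidden_states = torch.randn(2, 77, 1280)
+ >>> out = unet(sample, timestep, encoder_hidden_states)[0] # (2, 4, 32, 32)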
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2 ** self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
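+ # convert a {0, 1} key-padding mask into an additive bias: masked positions get a large
+ # negative value so they contribute ~0 after the softmax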
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError(
+ "class_labels should be provided when num_class_embeds > 0"
+ )
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if (
+ hasattr(downsample_block, "has_cross_attention")
+ and downsample_block.has_cross_attention
+ ):
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample += down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ if mid_block_additional_residual is not None:
+ sample += mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[
+ : -len(upsample_block.resnets)
+ ]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if (
+ hasattr(upsample_block, "has_cross_attention")
+ and upsample_block.has_cross_attention
+ ):
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = (
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ )
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ ):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
+ ):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class Upsample2D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ use_conv_transpose: a bool determining if a transposed convolution is used for upsampling.
+ out_channels: the number of output channels; defaults to `channels` when not set.
+ """
+
+ def __init__(
+ self,
+ channels,
+ use_conv=False,
+ use_conv_transpose=False,
+ out_channels=None,
+ name="conv",
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ # assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+ # https://github.com/pytorch/pytorch/issues/86679
+ dtype = hidden_states.dtype
+ # if dtype == torch.bfloat16:
+ # hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ # if hidden_states.shape[0] >= 64:
+ # hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(
+ hidden_states, scale_factor=2.0, mode="nearest"
+ )
+ else:
+ hidden_states = F.interpolate(
+ hidden_states, size=output_size, mode="nearest"
+ )
+
+ # If the input is bfloat16, we cast back to bfloat16
+ # if dtype == torch.bfloat16:
+ # hidden_states = hidden_states.to(dtype)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+def _configure_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _get_default_logging_level():
+ """
+ If the DIFFUSERS_VERBOSITY env var is set to one of the valid choices, return that as the new default
+ level. If it is not, fall back to `_default_log_level`.
+ """
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_get_library_name())
+
+
+def compare_versions(
+ library_or_version: Union[str, Version], operation: str, requirement_version: str
+):
+ """
+ Compares a library version to some requirement using a given operation.
+
+ Args:
+ library_or_version (`str` or `packaging.version.Version`):
+ A library name or a version to check.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`.
+ requirement_version (`str`):
+ The version to compare the library version against.
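+
+ Example (illustrative):
+
+ >>> compare_versions("torch", ">=", "1.12.0") # True if the installed torch is at least 1.12.0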
+ """
+ if operation not in STR_OPERATION_TO_FUNC.keys():
+ raise ValueError(
+ f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}"
+ )
+ operation = STR_OPERATION_TO_FUNC[operation]
+ if isinstance(library_or_version, str):
+ library_or_version = parse(importlib_metadata.version(library_or_version))
+ return operation(library_or_version, parse(requirement_version))
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
+ down_block_type = (
+ down_block_type[7:]
+ if down_block_type.startswith("UNetRes")
+ else down_block_type
+ )
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "ResnetDownsampleBlock2D":
+ return ResnetDownsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError(
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
+ )
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError(
+ "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D"
+ )
+ return SimpleCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnDownEncoderBlock2D":
+ return AttnDownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "KDownBlock2D":
+ return KDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif down_block_type == "KCrossAttnDownBlock2D":
+ return KCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ add_self_attention=not add_downsample,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
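+
+ Example (illustrative shapes):
+
+ >>> t = torch.tensor([1, 10, 100])
+ >>> emb = get_timestep_embedding(t, embedding_dim=128)
+ >>> emb.shape # torch.Size([3, 128])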
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
+ up_block_type = (
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ )
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "ResnetUpsampleBlock2D":
+ return ResnetUpsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError(
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
+ )
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError(
+ "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D"
+ )
+ return SimpleCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnUpDecoderBlock2D":
+ return AttnUpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attn_num_head_channels=attn_num_head_channels,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "KUpBlock2D":
+ return KUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "KCrossAttnUpBlock2D":
+ return KCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+def interpolate(
+ input: Tensor,
+ size: Optional[int] = None,
+ scale_factor: Optional[List[float]] = None,
+ mode: str = "nearest",
+ align_corners: Optional[bool] = None,
+ recompute_scale_factor: Optional[bool] = None,
+ antialias: bool = False,
+) -> Tensor: # noqa: F811
+ r"""Down/up samples the input to either the given :attr:`size` or the given
+ :attr:`scale_factor`
+
+ The algorithm used for interpolation is determined by :attr:`mode`.
+
+ Currently temporal, spatial and volumetric sampling are supported, i.e.
+ expected inputs are 3-D, 4-D or 5-D in shape.
+
+ The input dimensions are interpreted in the form:
+ `mini-batch x channels x [optional depth] x [optional height] x width`.
+
+ The modes available for resizing are: `nearest`, `linear` (3D-only),
+ `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact`
+
+ Args:
+ input (Tensor): the input tensor
+ size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
+ output spatial size.
+ scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple,
+ its length has to match the number of spatial dimensions; `input.dim() - 2`.
+ mode (str): algorithm used for upsampling:
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+ ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'``
+ align_corners (bool, optional): Geometrically, we consider the pixels of the
+ input and output as squares rather than points.
+ If set to ``True``, the input and output tensors are aligned by the
+ center points of their corner pixels, preserving the values at the corner pixels.
+ If set to ``False``, the input and output tensors are aligned by the corner
+ points of their corner pixels, and the interpolation uses edge value padding
+ for out-of-boundary values, making this operation *independent* of input size
+ when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
+ is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
+ Default: ``False``
+ recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
+ interpolation calculation. If `recompute_scale_factor` is ``True``, then
+ `scale_factor` must be passed in and `scale_factor` is used to compute the
+ output `size`. The computed output `size` will be used to infer new scales for
+ the interpolation. Note that when `scale_factor` is floating-point, it may differ
+ from the recomputed `scale_factor` due to rounding and precision issues.
+ If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
+ be used directly for interpolation. Default: ``None``.
+ antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias
+ option together with ``align_corners=False``, interpolation result would match Pillow
+ result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``.
+
+ .. note::
+ With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
+ negative values or values greater than 255 for images.
+ Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
+ when displaying the image.
+
+ .. note::
+ Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation
+ algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep
+ backward compatibility.
+ Mode ``mode='nearest'`` matches buggy OpenCV's ``INTER_NEAREST`` interpolation algorithm.
+
+ Note:
+ {backward_reproducibility_note}
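+
+ Example (illustrative):
+
+ >>> x = torch.randn(1, 3, 16, 16)
+ >>> interpolate(x, scale_factor=2.0, mode="nearest").shape # torch.Size([1, 3, 32, 32])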
+ """
+ if has_torch_function_unary(input):
+ return handle_torch_function(
+ interpolate,
+ (input,),
+ input,
+ size=size,
+ scale_factor=scale_factor,
+ mode=mode,
+ align_corners=align_corners,
+ recompute_scale_factor=recompute_scale_factor,
+ antialias=antialias,
+ )
+
+ if mode in ("nearest", "area", "nearest-exact"):
+ if align_corners is not None:
+ raise ValueError(
+ "align_corners option can only be set with the "
+ "interpolating modes: linear | bilinear | bicubic | trilinear"
+ )
+ else:
+ if align_corners is None:
+ align_corners = False
+
+ dim = input.dim() - 2 # Number of spatial dimensions.
+
+ # Process size and scale_factor. Validate that exactly one is set.
+ # Validate its length if it is a list, or expand it if it is a scalar.
+ # After this block, exactly one of output_size and scale_factors will
+ # be non-None, and it will be a list (or tuple).
+ if size is not None and scale_factor is not None:
+ raise ValueError("only one of size or scale_factor should be defined")
+ elif size is not None:
+ assert scale_factor is None
+ scale_factors = None
+ if isinstance(size, (list, tuple)):
+ if len(size) != dim:
+ raise ValueError(
+ "Input and output must have the same number of spatial dimensions, but got "
+ f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
+ "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
+ "output size in (o1, o2, ...,oK) format."
+ )
+ output_size = size
+ else:
+ output_size = [size for _ in range(dim)]
+ elif scale_factor is not None:
+ assert size is None
+ output_size = None
+ if isinstance(scale_factor, (list, tuple)):
+ if len(scale_factor) != dim:
+ raise ValueError(
+ "Input and scale_factor must have the same number of spatial dimensions, but "
+ f"got input with spatial dimensions of {list(input.shape[2:])} and "
+ f"scale_factor of shape {scale_factor}. "
+ "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
+ "scale_factor in (s1, s2, ...,sK) format."
+ )
+ scale_factors = scale_factor
+ else:
+ scale_factors = [scale_factor for _ in range(dim)]
+ else:
+ raise ValueError("either size or scale_factor should be defined")
+
+ if (
+ recompute_scale_factor is not None
+ and recompute_scale_factor
+ and size is not None
+ ):
+ raise ValueError(
+ "recompute_scale_factor is not meaningful with an explicit size."
+ )
+
+ # "area" mode always requires an explicit size rather than scale factor.
+ # Re-use the recompute_scale_factor code path.
+ if mode == "area" and output_size is None:
+ recompute_scale_factor = True
+
+ if recompute_scale_factor is not None and recompute_scale_factor:
+ # We compute output_size here, then un-set scale_factors.
+ # The C++ code will recompute it based on the (integer) output size.
+ assert scale_factors is not None
+ if not torch.jit.is_scripting() and torch._C._get_tracing_state():
+ # make scale_factor a tensor in tracing so constant doesn't get baked in
+ output_size = [
+ (
+ torch.floor(
+ (
+ input.size(i + 2).float()
+ * torch.tensor(scale_factors[i], dtype=torch.float32)
+ ).float()
+ )
+ )
+ for i in range(dim)
+ ]
+ elif torch.jit.is_scripting():
+ output_size = [
+ int(math.floor(float(input.size(i + 2)) * scale_factors[i]))
+ for i in range(dim)
+ ]
+ else:
+ output_size = [
+ _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim)
+ ]
+ scale_factors = None
+
+ if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
+ raise ValueError(
+ "Anti-alias option is only supported for bilinear and bicubic modes"
+ )
+
+ if input.dim() == 3 and mode == "nearest":
+ return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
+ if input.dim() == 4 and mode == "nearest":
+ return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
+ if input.dim() == 5 and mode == "nearest":
+ return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
+
+ if input.dim() == 3 and mode == "nearest-exact":
+ return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
+ if input.dim() == 4 and mode == "nearest-exact":
+ return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
+ if input.dim() == 5 and mode == "nearest-exact":
+ return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
+
+ if input.dim() == 3 and mode == "area":
+ assert output_size is not None
+ return adaptive_avg_pool1d(input, output_size)
+ if input.dim() == 4 and mode == "area":
+ assert output_size is not None
+ return adaptive_avg_pool2d(input, output_size)
+ if input.dim() == 5 and mode == "area":
+ assert output_size is not None
+ return adaptive_avg_pool3d(input, output_size)
+
+ if input.dim() == 3 and mode == "linear":
+ assert align_corners is not None
+ return torch._C._nn.upsample_linear1d(
+ input, output_size, align_corners, scale_factors
+ )
+ if input.dim() == 4 and mode == "bilinear":
+ assert align_corners is not None
+ if antialias:
+ return torch._C._nn._upsample_bilinear2d_aa(
+ input, output_size, align_corners, scale_factors
+ )
+ return torch._C._nn.upsample_bilinear2d(
+ input, output_size, align_corners, scale_factors
+ )
+ if input.dim() == 5 and mode == "trilinear":
+ assert align_corners is not None
+ return torch._C._nn.upsample_trilinear3d(
+ input, output_size, align_corners, scale_factors
+ )
+ if input.dim() == 4 and mode == "bicubic":
+ assert align_corners is not None
+ if antialias:
+ return torch._C._nn._upsample_bicubic2d_aa(
+ input, output_size, align_corners, scale_factors
+ )
+ return torch._C._nn.upsample_bicubic2d(
+ input, output_size, align_corners, scale_factors
+ )
+
+ if input.dim() == 3 and mode == "bilinear":
+ raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+ if input.dim() == 3 and mode == "trilinear":
+ raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+ if input.dim() == 4 and mode == "linear":
+ raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
+ if input.dim() == 4 and mode == "trilinear":
+ raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+ if input.dim() == 5 and mode == "linear":
+ raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
+ if input.dim() == 5 and mode == "bilinear":
+ raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
+
+ raise NotImplementedError(
+ "Input Error: Only 3D, 4D and 5D input Tensors supported"
+ " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
+ " (got {})".format(input.dim(), mode)
+ )
+
+
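+# Optional-dependency probes: each helper simply reports the corresponding
+# ``_*_available`` flag, which is assumed to be set earlier in this module at
+# import time (as in upstream diffusers' import utilities).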
+def is_accelerate_available():
+ return _accelerate_available
+
+
+def is_flax_available():
+ return _flax_available
+
+
+def is_k_diffusion_available():
+ return _k_diffusion_available
+
+
+def is_librosa_available():
+ return _librosa_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_safetensors_available():
+ return _safetensors_available
+
+
+def is_scipy_available():
+ return _scipy_available
+
+
+def is_tensor(obj):
+ r"""Returns True if `obj` is a PyTorch tensor.
+
+ Note that this function is simply doing ``isinstance(obj, Tensor)``.
+ Using that ``isinstance`` check is better for typechecking with mypy,
+ and more explicit - so it's recommended to use that instead of
+ ``is_tensor``.
+
+ Args:
+ obj (Object): Object to test
+ Example::
+
+ >>> x = torch.tensor([1, 2, 3])
+ >>> torch.is_tensor(x)
+ True
+
+ """
+ return isinstance(obj, torch.Tensor)
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_torch_version(operation: str, version: str):
+ """
+    Compares the current PyTorch version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`.
+        version (`str`):
+            A string version of PyTorch.
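+
+    Example (illustrative; the result depends on the installed torch)::
+
+        >>> is_torch_version(">=", "1.12")  # True only if the installed PyTorch is at least 1.12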
+ """
+ return compare_versions(parse(_torch_version), operation, version)
+
+
+def is_transformers_available():
+ return _transformers_available
+
+
+def is_xformers_available():
+ return _xformers_available
diff --git a/pi/models/unet/prologue.py b/pi/models/unet/prologue.py
new file mode 100644
index 0000000..ba291fc
--- /dev/null
+++ b/pi/models/unet/prologue.py
@@ -0,0 +1,10 @@
+import os
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
+SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
diff --git a/pi/models/unet/resnet.py b/pi/models/unet/resnet.py
deleted file mode 100644
index ec95523..0000000
--- a/pi/models/unet/resnet.py
+++ /dev/null
@@ -1,752 +0,0 @@
-from functools import partial
-
-from ... import nn
-from ... import pi
-from ...nn import functional as F
-
-
-class Upsample1D(nn.Module):
- """
- An upsampling layer with an optional convolution.
-
- Parameters:
- channels: channels in the inputs and outputs.
- use_conv: a bool determining if a convolution is applied.
- use_conv_transpose:
- out_channels:
- """
-
- def __init__(
- self,
- channels,
- use_conv=False,
- use_conv_transpose=False,
- out_channels=None,
- name="conv",
- ):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.use_conv_transpose = use_conv_transpose
- self.name = name
-
- self.conv = None
- if use_conv_transpose:
- self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
- elif use_conv:
- self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- if self.use_conv_transpose:
- return self.conv(x)
-
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
-
- if self.use_conv:
- x = self.conv(x)
-
- return x
-
-
-class Downsample1D(nn.Module):
- """
- A downsampling layer with an optional convolution.
-
- Parameters:
- channels: channels in the inputs and outputs.
- use_conv: a bool determining if a convolution is applied.
- out_channels:
- padding:
- """
-
- def __init__(
- self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
- ):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.padding = padding
- stride = 2
- self.name = name
-
- if use_conv:
- self.conv = nn.Conv1d(
- self.channels, self.out_channels, 3, stride=stride, padding=padding
- )
- else:
- assert self.channels == self.out_channels
- self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- return self.conv(x)
-
-
-class Upsample2D(nn.Module):
- """
- An upsampling layer with an optional convolution.
-
- Parameters:
- channels: channels in the inputs and outputs.
- use_conv: a bool determining if a convolution is applied.
- use_conv_transpose:
- out_channels:
- """
-
- def __init__(
- self,
- channels,
- use_conv=False,
- use_conv_transpose=False,
- out_channels=None,
- name="conv",
- ):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.use_conv_transpose = use_conv_transpose
- self.name = name
-
- conv = None
- if use_conv_transpose:
- conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
- elif use_conv:
- conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
-
- # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
- if name == "conv":
- self.conv = conv
- else:
- self.Conv2d_0 = conv
-
- def forward(self, hidden_states, output_size=None):
- assert hidden_states.shape[1] == self.channels
-
- if self.use_conv_transpose:
- return self.conv(hidden_states)
-
- # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
- # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
- # https://github.com/pytorch/pytorch/issues/86679
- dtype = hidden_states.dtype
- if dtype == pi.bfloat16:
- hidden_states = hidden_states.to(pi.float32)
-
- # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
- if hidden_states.shape[0] >= 64:
- hidden_states = hidden_states.contiguous()
-
- # if `output_size` is passed we force the interpolation output
- # size and do not make use of `scale_factor=2`
- if output_size is None:
- hidden_states = F.interpolate(
- hidden_states, scale_factor=2.0, mode="nearest"
- )
- else:
- hidden_states = F.interpolate(
- hidden_states, size=output_size, mode="nearest"
- )
-
- # If the input is bfloat16, we cast back to bfloat16
- if dtype == pi.bfloat16:
- hidden_states = hidden_states.to(dtype)
-
- # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
- if self.use_conv:
- if self.name == "conv":
- hidden_states = self.conv(hidden_states)
- else:
- hidden_states = self.Conv2d_0(hidden_states)
-
- return hidden_states
-
-
-class Downsample2D(nn.Module):
- """
- A downsampling layer with an optional convolution.
-
- Parameters:
- channels: channels in the inputs and outputs.
- use_conv: a bool determining if a convolution is applied.
- out_channels:
- padding:
- """
-
- def __init__(
- self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
- ):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.padding = padding
- stride = 2
- self.name = name
-
- if use_conv:
- conv = nn.Conv2d(
- self.channels, self.out_channels, 3, stride=stride, padding=padding
- )
- else:
- assert self.channels == self.out_channels
- conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
-
- # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
- if name == "conv":
- self.Conv2d_0 = conv
- self.conv = conv
- elif name == "Conv2d_0":
- self.conv = conv
- else:
- self.conv = conv
-
- def forward(self, hidden_states):
- assert hidden_states.shape[1] == self.channels
- if self.use_conv and self.padding == 0:
- pad = (0, 1, 0, 1)
- hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
-
- assert hidden_states.shape[1] == self.channels
- hidden_states = self.conv(hidden_states)
-
- return hidden_states
-
-
-class FirUpsample2D(nn.Module):
- def __init__(
- self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)
- ):
- super().__init__()
- out_channels = out_channels if out_channels else channels
- if use_conv:
- self.Conv2d_0 = nn.Conv2d(
- channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- self.use_conv = use_conv
- self.fir_kernel = fir_kernel
- self.out_channels = out_channels
-
- def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
- """Fused `upsample_2d()` followed by `Conv2d()`.
-
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
- arbitrary order.
-
- Args:
- hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- weight: Weight tensor of the shape `[filterH, filterW, inChannels,
- outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
- factor: Integer upsampling factor (default: 2).
- gain: Scaling factor for signal magnitude (default: 1.0).
-
- Returns:
- output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
- datatype as `hidden_states`.
- """
-
- assert isinstance(factor, int) and factor >= 1
-
- # Setup filter kernel.
- if kernel is None:
- kernel = [1] * factor
-
- # setup kernel
- kernel = pi.tensor(kernel, dtype=pi.float32)
- if kernel.ndim == 1:
- kernel = pi.outer(kernel, kernel)
- kernel /= pi.sum(kernel)
-
- kernel = kernel * (gain * (factor ** 2))
-
- if self.use_conv:
- convH = weight.shape[2]
- convW = weight.shape[3]
- inC = weight.shape[1]
-
- pad_value = (kernel.shape[0] - factor) - (convW - 1)
-
- stride = (factor, factor)
- # Determine data dimensions.
- output_shape = (
- (hidden_states.shape[2] - 1) * factor + convH,
- (hidden_states.shape[3] - 1) * factor + convW,
- )
- output_padding = (
- output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
- output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
- )
- assert output_padding[0] >= 0 and output_padding[1] >= 0
- num_groups = hidden_states.shape[1] // inC
-
- # Transpose weights.
- weight = pi.reshape(weight, (num_groups, -1, inC, convH, convW))
- weight = pi.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
- weight = pi.reshape(weight, (num_groups * inC, -1, convH, convW))
-
- inverse_conv = F.conv_transpose2d(
- hidden_states,
- weight,
- stride=stride,
- output_padding=output_padding,
- padding=0,
- )
-
- output = upfirdn2d_native(
- inverse_conv,
- pi.tensor(kernel, device=inverse_conv.device),
- pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
- )
- else:
- pad_value = kernel.shape[0] - factor
- output = upfirdn2d_native(
- hidden_states,
- pi.tensor(kernel, device=hidden_states.device),
- up=factor,
- pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
- )
-
- return output
-
- def forward(self, hidden_states):
- if self.use_conv:
- height = self._upsample_2d(
- hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel
- )
- height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
- else:
- height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
-
- return height
-
-
-class FirDownsample2D(nn.Module):
- def __init__(
- self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)
- ):
- super().__init__()
- out_channels = out_channels if out_channels else channels
- if use_conv:
- self.Conv2d_0 = nn.Conv2d(
- channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- self.fir_kernel = fir_kernel
- self.use_conv = use_conv
- self.out_channels = out_channels
-
- def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
- """Fused `Conv2d()` followed by `downsample_2d()`.
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
- arbitrary order.
-
- Args:
- hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- weight:
- Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
- performed by `inChannels = x.shape[0] // numGroups`.
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
- factor`, which corresponds to average pooling.
- factor: Integer downsampling factor (default: 2).
- gain: Scaling factor for signal magnitude (default: 1.0).
-
- Returns:
- output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
- same datatype as `x`.
- """
-
- assert isinstance(factor, int) and factor >= 1
- if kernel is None:
- kernel = [1] * factor
-
- # setup kernel
- kernel = pi.tensor(kernel, dtype=pi.float32)
- if kernel.ndim == 1:
- kernel = pi.outer(kernel, kernel)
- kernel /= pi.sum(kernel)
-
- kernel = kernel * gain
-
- if self.use_conv:
- _, _, convH, convW = weight.shape
- pad_value = (kernel.shape[0] - factor) + (convW - 1)
- stride_value = [factor, factor]
- upfirdn_input = upfirdn2d_native(
- hidden_states,
- pi.tensor(kernel, device=hidden_states.device),
- pad=((pad_value + 1) // 2, pad_value // 2),
- )
- output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
- else:
- pad_value = kernel.shape[0] - factor
- output = upfirdn2d_native(
- hidden_states,
- pi.tensor(kernel, device=hidden_states.device),
- down=factor,
- pad=((pad_value + 1) // 2, pad_value // 2),
- )
-
- return output
-
- def forward(self, hidden_states):
- if self.use_conv:
- downsample_input = self._downsample_2d(
- hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel
- )
- hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
- else:
- hidden_states = self._downsample_2d(
- hidden_states, kernel=self.fir_kernel, factor=2
- )
-
- return hidden_states
-
-
-class ResnetBlock2D(nn.Module):
- def __init__(
- self,
- *,
- in_channels,
- out_channels=None,
- conv_shortcut=False,
- dropout=0.0,
- temb_channels=512,
- groups=32,
- groups_out=None,
- pre_norm=True,
- eps=1e-6,
- non_linearity="swish",
- time_embedding_norm="default",
- kernel=None,
- output_scale_factor=1.0,
- use_in_shortcut=None,
- up=False,
- down=False,
- ):
- super().__init__()
- self.pre_norm = pre_norm
- self.pre_norm = True
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.use_conv_shortcut = conv_shortcut
- self.time_embedding_norm = time_embedding_norm
- self.up = up
- self.down = down
- self.output_scale_factor = output_scale_factor
-
- if groups_out is None:
- groups_out = groups
-
- self.norm1 = pi.nn.GroupNorm(
- num_groups=groups, num_channels=in_channels, eps=eps, affine=True
- )
-
- self.conv1 = pi.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
-
- if temb_channels is not None:
- if self.time_embedding_norm == "default":
- time_emb_proj_out_channels = out_channels
- elif self.time_embedding_norm == "scale_shift":
- time_emb_proj_out_channels = out_channels * 2
- else:
- raise ValueError(
- f"unknown time_embedding_norm : {self.time_embedding_norm} "
- )
-
- self.time_emb_proj = pi.nn.Linear(temb_channels, time_emb_proj_out_channels)
- else:
- self.time_emb_proj = None
-
- self.norm2 = pi.nn.GroupNorm(
- num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
- )
- self.dropout = pi.nn.Dropout(dropout)
- self.conv2 = pi.nn.Conv2d(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
-
- if non_linearity == "swish":
- self.nonlinearity = lambda x: F.silu(x)
- elif non_linearity == "mish":
- self.nonlinearity = Mish()
- elif non_linearity == "silu":
- self.nonlinearity = nn.SiLU()
-
- self.upsample = self.downsample = None
- if self.up:
- if kernel == "fir":
- fir_kernel = (1, 3, 3, 1)
- self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
- elif kernel == "sde_vp":
- self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
- else:
- self.upsample = Upsample2D(in_channels, use_conv=False)
- elif self.down:
- if kernel == "fir":
- fir_kernel = (1, 3, 3, 1)
- self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
- elif kernel == "sde_vp":
- self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
- else:
- self.downsample = Downsample2D(
- in_channels, use_conv=False, padding=1, name="op"
- )
-
- self.use_in_shortcut = (
- self.in_channels != self.out_channels
- if use_in_shortcut is None
- else use_in_shortcut
- )
-
- self.conv_shortcut = None
- if self.use_in_shortcut:
- self.conv_shortcut = pi.nn.Conv2d(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, input_tensor, temb):
- hidden_states = input_tensor
-
- hidden_states = self.norm1(hidden_states)
- hidden_states = self.nonlinearity(hidden_states)
-
- if self.upsample is not None:
- # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
- if hidden_states.shape[0] >= 64:
- input_tensor = input_tensor.contiguous()
- hidden_states = hidden_states.contiguous()
- input_tensor = self.upsample(input_tensor)
- hidden_states = self.upsample(hidden_states)
- elif self.downsample is not None:
- input_tensor = self.downsample(input_tensor)
- hidden_states = self.downsample(hidden_states)
-
- hidden_states = self.conv1(hidden_states)
-
- if temb is not None:
- temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
-
- if temb is not None and self.time_embedding_norm == "default":
- hidden_states = hidden_states + temb
-
- hidden_states = self.norm2(hidden_states)
-
- if temb is not None and self.time_embedding_norm == "scale_shift":
- scale, shift = pi.chunk(temb, 2, dim=1)
- hidden_states = hidden_states * (1 + scale) + shift
-
- hidden_states = self.nonlinearity(hidden_states)
-
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.conv2(hidden_states)
-
- if self.conv_shortcut is not None:
- input_tensor = self.conv_shortcut(input_tensor)
-
- output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
-
- return output_tensor
-
-
-class Mish(pi.nn.Module):
- def forward(self, hidden_states):
- return hidden_states * pi.tanh(pi.nn.functional.softplus(hidden_states))
-
-
-# unet_rl.py
-def rearrange_dims(tensor):
- if len(tensor.shape) == 2:
- return tensor[:, :, None]
- if len(tensor.shape) == 3:
- return tensor[:, :, None, :]
- elif len(tensor.shape) == 4:
- return tensor[:, :, 0, :]
- else:
- raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
-
-
-class Conv1dBlock(nn.Module):
- """
- Conv1d --> GroupNorm --> Mish
- """
-
- def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
- super().__init__()
-
- self.conv1d = nn.Conv1d(
- inp_channels, out_channels, kernel_size, padding=kernel_size // 2
- )
- self.group_norm = nn.GroupNorm(n_groups, out_channels)
- self.mish = nn.Mish()
-
- def forward(self, x):
- x = self.conv1d(x)
- x = rearrange_dims(x)
- x = self.group_norm(x)
- x = rearrange_dims(x)
- x = self.mish(x)
- return x
-
-
-# unet_rl.py
-class ResidualTemporalBlock1D(nn.Module):
- def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
- super().__init__()
- self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
- self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
-
- self.time_emb_act = nn.Mish()
- self.time_emb = nn.Linear(embed_dim, out_channels)
-
- self.residual_conv = (
- nn.Conv1d(inp_channels, out_channels, 1)
- if inp_channels != out_channels
- else nn.Identity()
- )
-
- def forward(self, x, t):
- """
- Args:
- x : [ batch_size x inp_channels x horizon ]
- t : [ batch_size x embed_dim ]
-
- returns:
- out : [ batch_size x out_channels x horizon ]
- """
- t = self.time_emb_act(t)
- t = self.time_emb(t)
- out = self.conv_in(x) + rearrange_dims(t)
- out = self.conv_out(out)
- return out + self.residual_conv(x)
-
-
-def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
- r"""Upsample2D a batch of 2D images with the given filter.
- Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
- filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
- `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
- a: multiple of the upsampling factor.
-
- Args:
- hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
- factor: Integer upsampling factor (default: 2).
- gain: Scaling factor for signal magnitude (default: 1.0).
-
- Returns:
- output: Tensor of the shape `[N, C, H * factor, W * factor]`
- """
- assert isinstance(factor, int) and factor >= 1
- if kernel is None:
- kernel = [1] * factor
-
- kernel = pi.tensor(kernel, dtype=pi.float32)
- if kernel.ndim == 1:
- kernel = pi.outer(kernel, kernel)
- kernel /= pi.sum(kernel)
-
- kernel = kernel * (gain * (factor ** 2))
- pad_value = kernel.shape[0] - factor
- output = upfirdn2d_native(
- hidden_states,
- kernel.to(device=hidden_states.device),
- up=factor,
- pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
- )
- return output
-
-
-def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
- r"""Downsample2D a batch of 2D images with the given filter.
- Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
- given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
- specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
- shape is a multiple of the downsampling factor.
-
- Args:
- hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
- (separable). The default is `[1] * factor`, which corresponds to average pooling.
- factor: Integer downsampling factor (default: 2).
- gain: Scaling factor for signal magnitude (default: 1.0).
-
- Returns:
- output: Tensor of the shape `[N, C, H // factor, W // factor]`
- """
-
- assert isinstance(factor, int) and factor >= 1
- if kernel is None:
- kernel = [1] * factor
-
- kernel = pi.tensor(kernel, dtype=pi.float32)
- if kernel.ndim == 1:
- kernel = pi.outer(kernel, kernel)
- kernel /= pi.sum(kernel)
-
- kernel = kernel * gain
- pad_value = kernel.shape[0] - factor
- output = upfirdn2d_native(
- hidden_states,
- kernel.to(device=hidden_states.device),
- down=factor,
- pad=((pad_value + 1) // 2, pad_value // 2),
- )
- return output
-
-
-def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
- up_x = up_y = up
- down_x = down_y = down
- pad_x0 = pad_y0 = pad[0]
- pad_x1 = pad_y1 = pad[1]
-
- _, channel, in_h, in_w = tensor.shape
- tensor = tensor.reshape(-1, in_h, in_w, 1)
-
- _, in_h, in_w, minor = tensor.shape
- kernel_h, kernel_w = kernel.shape
-
- out = tensor.view(-1, in_h, 1, in_w, 1, minor)
- out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
- out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
- out = F.pad(
- out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
- )
- out = out.to(tensor.device) # Move back to mps if necessary
- out = out[
- :,
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
- :,
- ]
-
- out = out.permute(0, 3, 1, 2)
- out = out.reshape(
- [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
- )
- w = pi.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
- out = F.conv2d(out, w)
- out = out.reshape(
- -1,
- minor,
- in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
- )
- out = out.permute(0, 2, 3, 1)
- out = out[:, ::down_y, ::down_x, :]
-
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
- return out.view(-1, channel, out_h, out_w)
diff --git a/pi/models/unet/transformer_2d.py b/pi/models/unet/transformer_2d.py
deleted file mode 100644
index 5376875..0000000
--- a/pi/models/unet/transformer_2d.py
+++ /dev/null
@@ -1,357 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Optional
-
-from ... import nn
-from ... import pi
-from ...nn import functional as F
-
-from .attention import BasicTransformerBlock
-from .embeddings import PatchEmbed, ImagePositionalEmbeddings
-
-
-@dataclass
-class Transformer2DModelOutput:
- """
- Args:
- sample (`pi.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
- Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
- for the unnoised latent pixels.
- """
-
- sample: pi.FloatTensor
-
-
-class Transformer2DModel(nn.Module):
- """
- Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
- embeddings) inputs.
-
- When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
- transformer action. Finally, reshape to image.
-
- When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
- embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
- classes of unnoised image.
-
- Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
- image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
-
- Parameters:
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
- in_channels (`int`, *optional*):
- Pass if the input is continuous. The number of channels in the input and output.
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
- `ImagePositionalEmbeddings`.
- num_vector_embeds (`int`, *optional*):
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
- Includes the class for the masked latent pixel.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
- up to but not more than steps than `num_embeds_ada_norm`.
- attention_bias (`bool`, *optional*):
- Configure if the TransformerBlocks' attention should contain a bias parameter.
- """
-
- def __init__(
- self,
- num_attention_heads: int = 16,
- attention_head_dim: int = 88,
- in_channels: Optional[int] = None,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- dropout: float = 0.0,
- norm_num_groups: int = 32,
- cross_attention_dim: Optional[int] = None,
- attention_bias: bool = False,
- sample_size: Optional[int] = None,
- num_vector_embeds: Optional[int] = None,
- patch_size: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- norm_type: str = "layer_norm",
- norm_elementwise_affine: bool = True,
- ):
- super().__init__()
- self.use_linear_projection = use_linear_projection
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
- inner_dim = num_attention_heads * attention_head_dim
-
- # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
- # Define whether input is continuous or discrete depending on configuration
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
- self.is_input_vectorized = num_vector_embeds is not None
- self.is_input_patches = in_channels is not None and patch_size is not None
-
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
- deprecation_message = (
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
- )
- norm_type = "ada_norm"
-
- if self.is_input_continuous and self.is_input_vectorized:
- raise ValueError(
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
- " sure that either `in_channels` or `num_vector_embeds` is None."
- )
- elif self.is_input_vectorized and self.is_input_patches:
- raise ValueError(
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
- " sure that either `num_vector_embeds` or `num_patches` is None."
- )
- elif (
- not self.is_input_continuous
- and not self.is_input_vectorized
- and not self.is_input_patches
- ):
- raise ValueError(
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
- )
-
- # 2. Define input layers
- if self.is_input_continuous:
- self.in_channels = in_channels
-
- self.norm = pi.nn.GroupNorm(
- num_groups=norm_num_groups,
- num_channels=in_channels,
- eps=1e-6,
- affine=True,
- )
- if use_linear_projection:
- self.proj_in = nn.Linear(in_channels, inner_dim)
- else:
- self.proj_in = nn.Conv2d(
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
- )
- elif self.is_input_vectorized:
- assert (
- sample_size is not None
- ), "Transformer2DModel over discrete input must provide sample_size"
- assert (
- num_vector_embeds is not None
- ), "Transformer2DModel over discrete input must provide num_embed"
-
- self.height = sample_size
- self.width = sample_size
- self.num_vector_embeds = num_vector_embeds
- self.num_latent_pixels = self.height * self.width
-
- self.latent_image_embedding = ImagePositionalEmbeddings(
- num_embed=num_vector_embeds,
- embed_dim=inner_dim,
- height=self.height,
- width=self.width,
- )
- elif self.is_input_patches:
- assert (
- sample_size is not None
- ), "Transformer2DModel over patched input must provide sample_size"
-
- self.height = sample_size
- self.width = sample_size
-
- self.patch_size = patch_size
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=inner_dim,
- )
-
- # 3. Define transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- BasicTransformerBlock(
- inner_dim,
- num_attention_heads,
- attention_head_dim,
- dropout=dropout,
- cross_attention_dim=cross_attention_dim,
- activation_fn=activation_fn,
- num_embeds_ada_norm=num_embeds_ada_norm,
- attention_bias=attention_bias,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- norm_type=norm_type,
- norm_elementwise_affine=norm_elementwise_affine,
- )
- for d in range(num_layers)
- ]
- )
-
- # 4. Define output layers
- self.out_channels = in_channels if out_channels is None else out_channels
- if self.is_input_continuous:
- # TODO: should use out_channels for continous projections
- if use_linear_projection:
- self.proj_out = nn.Linear(in_channels, inner_dim)
- else:
- self.proj_out = nn.Conv2d(
- inner_dim, in_channels, kernel_size=1, stride=1, padding=0
- )
- elif self.is_input_vectorized:
- self.norm_out = nn.LayerNorm(inner_dim)
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches:
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
- self.proj_out_2 = nn.Linear(
- inner_dim, patch_size * patch_size * self.out_channels
- )
-
- def forward(
- self,
- hidden_states,
- encoder_hidden_states=None,
- timestep=None,
- class_labels=None,
- cross_attention_kwargs=None,
- return_dict: bool = True,
- ):
- """
- Args:
- hidden_states ( When discrete, `pi.LongTensor` of shape `(batch size, num latent pixels)`.
- When continous, `pi.FloatTensor` of shape `(batch size, channel, height, width)`): Input
- hidden_states
- encoder_hidden_states ( `pi.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
- self-attention.
- timestep ( `pi.long`, *optional*):
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
- class_labels ( `pi.LongTensor` of shape `(batch size, num classes)`, *optional*):
- Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
- conditioning.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
- [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
- returning a tuple, the first element is the sample tensor.
- """
- # 1. Input
- if self.is_input_continuous:
- batch, _, height, width = hidden_states.shape
- residual = hidden_states
-
- hidden_states = self.norm(hidden_states)
- if not self.use_linear_projection:
- hidden_states = self.proj_in(hidden_states)
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
- batch, height * width, inner_dim
- )
- else:
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
- batch, height * width, inner_dim
- )
- hidden_states = self.proj_in(hidden_states)
- elif self.is_input_vectorized:
- hidden_states = self.latent_image_embedding(hidden_states)
- elif self.is_input_patches:
- hidden_states = self.pos_embed(hidden_states)
-
- # 2. Blocks
- for block in self.transformer_blocks:
- hidden_states = block(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- timestep=timestep,
- cross_attention_kwargs=cross_attention_kwargs,
- class_labels=class_labels,
- )
-
- # 3. Output
- if self.is_input_continuous:
- if not self.use_linear_projection:
- hidden_states = (
- hidden_states.reshape(batch, height, width, inner_dim)
- .permute(0, 3, 1, 2)
- .contiguous()
- )
- hidden_states = self.proj_out(hidden_states)
- else:
- hidden_states = self.proj_out(hidden_states)
- hidden_states = (
- hidden_states.reshape(batch, height, width, inner_dim)
- .permute(0, 3, 1, 2)
- .contiguous()
- )
-
- output = hidden_states + residual
- elif self.is_input_vectorized:
- hidden_states = self.norm_out(hidden_states)
- logits = self.out(hidden_states)
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
- logits = logits.permute(0, 2, 1)
-
- # log(p(x_0))
- output = F.log_softmax(logits.double(), dim=1).float()
- elif self.is_input_patches:
- # TODO: cleanup!
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = (
- self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- )
- hidden_states = self.proj_out_2(hidden_states)
-
- # unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
- hidden_states = hidden_states.reshape(
- shape=(
- -1,
- height,
- width,
- self.patch_size,
- self.patch_size,
- self.out_channels,
- )
- )
- hidden_states = pi.einsum("nhwpqc->nchpwq", hidden_states)
- output = hidden_states.reshape(
- shape=(
- -1,
- self.out_channels,
- height * self.patch_size,
- width * self.patch_size,
- )
- )
-
- if not return_dict:
- return (output,)
-
- return Transformer2DModelOutput(sample=output)
diff --git a/pi/models/unet/unet_2d_blocks.py b/pi/models/unet/unet_2d_blocks.py
deleted file mode 100644
index 75b51b5..0000000
--- a/pi/models/unet/unet_2d_blocks.py
+++ /dev/null
@@ -1,2312 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import numpy as np
-from diffusers import Transformer2DModel
-
-from .attention import AttentionBlock
-from .resnet import ResnetBlock2D
-from ... import nn, Tensor
-from ... import pi
-
-
-def get_down_block(
- down_block_type,
- num_layers,
- in_channels,
- out_channels,
- temb_channels,
- add_downsample,
- resnet_eps,
- resnet_act_fn,
- attn_num_head_channels,
- resnet_groups=None,
- cross_attention_dim=None,
- downsample_padding=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
-):
- down_block_type = (
- down_block_type[7:]
- if down_block_type.startswith("UNetRes")
- else down_block_type
- )
- if down_block_type == "DownBlock2D":
- return DownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "ResnetDownsampleBlock2D":
- return ResnetDownsampleBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "AttnDownBlock2D":
- return AttnDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "CrossAttnDownBlock2D":
- if cross_attention_dim is None:
- raise ValueError(
- "cross_attention_dim must be specified for CrossAttnDownBlock2D"
- )
- return CrossAttnDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "SimpleCrossAttnDownBlock2D":
- if cross_attention_dim is None:
- raise ValueError(
- "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D"
- )
- return SimpleCrossAttnDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "SkipDownBlock2D":
- return SkipDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "AttnSkipDownBlock2D":
- return AttnSkipDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- downsample_padding=downsample_padding,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "DownEncoderBlock2D":
- return DownEncoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "AttnDownEncoderBlock2D":
- return AttnDownEncoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- raise ValueError(f"{down_block_type} does not exist.")
-
-
-def get_up_block(
- up_block_type,
- num_layers,
- in_channels,
- out_channels,
- prev_output_channel,
- temb_channels,
- add_upsample,
- resnet_eps,
- resnet_act_fn,
- attn_num_head_channels,
- resnet_groups=None,
- cross_attention_dim=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
-):
- up_block_type = (
- up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
- )
- if up_block_type == "UpBlock2D":
- return UpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "ResnetUpsampleBlock2D":
- return ResnetUpsampleBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "CrossAttnUpBlock2D":
- if cross_attention_dim is None:
- raise ValueError(
- "cross_attention_dim must be specified for CrossAttnUpBlock2D"
- )
- return CrossAttnUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "SimpleCrossAttnUpBlock2D":
- if cross_attention_dim is None:
- raise ValueError(
- "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D"
- )
- return SimpleCrossAttnUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "AttnUpBlock2D":
- return AttnUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "SkipUpBlock2D":
- return SkipUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "AttnSkipUpBlock2D":
- return AttnSkipUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "UpDecoderBlock2D":
- return UpDecoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "AttnUpDecoderBlock2D":
- return AttnUpDecoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- attn_num_head_channels=attn_num_head_channels,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- raise ValueError(f"{up_block_type} does not exist.")
-
-
-class UNetMidBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- add_attention: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- ):
- super().__init__()
- resnet_groups = (
- resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- )
- self.add_attention = add_attention
-
- # there is always at least one resnet
- resnets = [
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- ]
- attentions = []
-
- for _ in range(num_layers):
- if self.add_attention:
- attentions.append(
- AttentionBlock(
- in_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- norm_num_groups=resnet_groups,
- )
- )
- else:
- attentions.append(None)
-
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- def forward(self, hidden_states, temb=None):
- hidden_states = self.resnets[0](hidden_states, temb)
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if attn is not None:
- hidden_states = attn(hidden_states)
- hidden_states = resnet(hidden_states, temb)
-
- return hidden_states
-
-
-class UNetMidBlock2DCrossAttn(nn.Module):
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- cross_attention_dim=1280,
- dual_cross_attention=False,
- use_linear_projection=False,
- upcast_attention=False,
- ):
- super().__init__()
-
- self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
- resnet_groups = (
- resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- )
-
- # there is always at least one resnet
- resnets = [
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- ]
- attentions = []
-
- for _ in range(num_layers):
- if not dual_cross_attention:
- attentions.append(
- Transformer2DModel(
- attn_num_head_channels,
- in_channels // attn_num_head_channels,
- in_channels=in_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
- )
- else:
- attentions.append(
- DualTransformer2DModel(
- attn_num_head_channels,
- in_channels // attn_num_head_channels,
- in_channels=in_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- )
- )
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- def forward(
- self,
- hidden_states,
- temb=None,
- encoder_hidden_states=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- ):
- hidden_states = self.resnets[0](hidden_states, temb)
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- ).sample
- hidden_states = resnet(hidden_states, temb)
-
- return hidden_states
-
-
-class UNetMidBlock2DSimpleCrossAttn(nn.Module):
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- cross_attention_dim=1280,
- ):
- super().__init__()
-
- self.has_cross_attention = True
-
- self.attn_num_head_channels = attn_num_head_channels
- resnet_groups = (
- resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- )
-
- self.num_heads = in_channels // self.attn_num_head_channels
-
- # there is always at least one resnet
- resnets = [
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- ]
- attentions = []
-
- for _ in range(num_layers):
- attentions.append(
- CrossAttention(
- query_dim=in_channels,
- cross_attention_dim=in_channels,
- heads=self.num_heads,
- dim_head=attn_num_head_channels,
- added_kv_proj_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- bias=True,
- upcast_softmax=True,
- processor=CrossAttnAddedKVProcessor(),
- )
- )
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- def forward(
- self,
- hidden_states,
- temb=None,
- encoder_hidden_states=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- ):
- cross_attention_kwargs = (
- cross_attention_kwargs if cross_attention_kwargs is not None else {}
- )
- hidden_states = self.resnets[0](hidden_states, temb)
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
- # attn
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-
- # resnet
- hidden_states = resnet(hidden_states, temb)
-
- return hidden_states
-
-
-class AttnDownBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- downsample_padding=1,
- add_downsample=True,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- attentions.append(
- AttentionBlock(
- out_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- norm_num_groups=resnet_groups,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- def forward(self, hidden_states, temb=None):
- output_states = ()
-
- for resnet, attn in zip(self.resnets, self.attentions):
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states)
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states
-
-
-class CrossAttnDownBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- cross_attention_dim=1280,
- output_scale_factor=1.0,
- downsample_padding=1,
- add_downsample=True,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- if not dual_cross_attention:
- attentions.append(
- Transformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
- in_channels=out_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- )
- )
- else:
- attentions.append(
- DualTransformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
- in_channels=out_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- )
- )
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states,
- temb=None,
- encoder_hidden_states=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- ):
- # TODO(Patrick, William) - attention mask is not used
- output_states = ()
-
- for resnet, attn in zip(self.resnets, self.attentions):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(attn, return_dict=False),
- hidden_states,
- encoder_hidden_states,
- cross_attention_kwargs,
- )[0]
- else:
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- ).sample
-
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states
-
-
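
The `create_custom_forward` closure in the forward pass above exists because `checkpoint` re-invokes the wrapped callable with positional tensor arguments during backward; closing over the module (and optionally forcing `return_dict=False`) keeps the checkpointed call tensor-only. A sketch of the same pattern with `torch.utils.checkpoint` and a hypothetical stand-in layer, not the real blocks:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def create_custom_forward(module, return_dict=None):
    # Close over the module (and an optional return_dict override) so that
    # checkpoint() only ever sees plain tensor arguments.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


layer = nn.Linear(16, 16)  # stand-in for a ResnetBlock2D / Transformer2DModel
x = torch.randn(4, 16, requires_grad=True)
# Activations inside `layer` are recomputed during backward instead of stored;
# use_reentrant=False is the non-reentrant mode available on recent PyTorch.
y = checkpoint(create_custom_forward(layer), x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])
```
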
-class DownBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor=1.0,
- add_downsample=True,
- downsample_padding=1,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(self, hidden_states, temb=None):
- output_states = ()
-
- for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
- else:
- hidden_states = resnet(hidden_states, temb)
-
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states
-
-
-class DownEncoderBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor=1.0,
- add_downsample=True,
- downsample_padding=1,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- def forward(self, hidden_states):
- for resnet in self.resnets:
- hidden_states = resnet(hidden_states, temb=None)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- return hidden_states
-
-
-class AttnDownEncoderBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- add_downsample=True,
- downsample_padding=1,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- attentions.append(
- AttentionBlock(
- out_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- norm_num_groups=resnet_groups,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- def forward(self, hidden_states):
- for resnet, attn in zip(self.resnets, self.attentions):
- hidden_states = resnet(hidden_states, temb=None)
- hidden_states = attn(hidden_states)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- return hidden_states
-
-
-class AttnSkipDownBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=np.sqrt(2.0),
- downsample_padding=1,
- add_downsample=True,
- ):
- super().__init__()
- self.attentions = nn.ModuleList([])
- self.resnets = nn.ModuleList([])
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- self.resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min(in_channels // 4, 32),
- groups_out=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- self.attentions.append(
- AttentionBlock(
- out_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- )
- )
-
- if add_downsample:
- self.resnet_down = ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- use_in_shortcut=True,
- down=True,
- kernel="fir",
- )
- self.downsamplers = nn.ModuleList(
- [FirDownsample2D(out_channels, out_channels=out_channels)]
- )
- self.skip_conv = nn.Conv2d(
- 3, out_channels, kernel_size=(1, 1), stride=(1, 1)
- )
- else:
- self.resnet_down = None
- self.downsamplers = None
- self.skip_conv = None
-
- def forward(self, hidden_states, temb=None, skip_sample=None):
- output_states = ()
-
- for resnet, attn in zip(self.resnets, self.attentions):
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states)
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- hidden_states = self.resnet_down(hidden_states, temb)
- for downsampler in self.downsamplers:
- skip_sample = downsampler(skip_sample)
-
- hidden_states = self.skip_conv(skip_sample) + hidden_states
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states, skip_sample
-
-
-class SkipDownBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_pre_norm: bool = True,
- output_scale_factor=np.sqrt(2.0),
- add_downsample=True,
- downsample_padding=1,
- ):
- super().__init__()
- self.resnets = nn.ModuleList([])
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- self.resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min(in_channels // 4, 32),
- groups_out=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- if add_downsample:
- self.resnet_down = ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- use_in_shortcut=True,
- down=True,
- kernel="fir",
- )
- self.downsamplers = nn.ModuleList(
- [FirDownsample2D(out_channels, out_channels=out_channels)]
- )
- self.skip_conv = nn.Conv2d(
- 3, out_channels, kernel_size=(1, 1), stride=(1, 1)
- )
- else:
- self.resnet_down = None
- self.downsamplers = None
- self.skip_conv = None
-
- def forward(self, hidden_states, temb=None, skip_sample=None):
- output_states = ()
-
- for resnet in self.resnets:
- hidden_states = resnet(hidden_states, temb)
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- hidden_states = self.resnet_down(hidden_states, temb)
- for downsampler in self.downsamplers:
- skip_sample = downsampler(skip_sample)
-
- hidden_states = self.skip_conv(skip_sample) + hidden_states
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states, skip_sample
-
-
-class ResnetDownsampleBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor=1.0,
- add_downsample=True,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- down=True,
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(self, hidden_states, temb=None):
- output_states = ()
-
- for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
- else:
- hidden_states = resnet(hidden_states, temb)
-
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, temb)
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states
-
-
-class SimpleCrossAttnDownBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- cross_attention_dim=1280,
- output_scale_factor=1.0,
- add_downsample=True,
- ):
- super().__init__()
-
- self.has_cross_attention = True
-
- resnets = []
- attentions = []
-
- self.attn_num_head_channels = attn_num_head_channels
- self.num_heads = out_channels // self.attn_num_head_channels
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- attentions.append(
- CrossAttention(
- query_dim=out_channels,
- cross_attention_dim=out_channels,
- heads=self.num_heads,
- dim_head=attn_num_head_channels,
- added_kv_proj_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- bias=True,
- upcast_softmax=True,
- processor=CrossAttnAddedKVProcessor(),
- )
- )
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- down=True,
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states,
- temb=None,
- encoder_hidden_states=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- ):
- output_states = ()
- cross_attention_kwargs = (
- cross_attention_kwargs if cross_attention_kwargs is not None else {}
- )
-
- for resnet, attn in zip(self.resnets, self.attentions):
- # resnet
- hidden_states = resnet(hidden_states, temb)
-
- # attn
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-
- output_states += (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, temb)
-
- output_states += (hidden_states,)
-
- return hidden_states, output_states
-
-
-class AttnUpBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- add_upsample=True,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- attentions.append(
- AttentionBlock(
- out_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- norm_num_groups=resnet_groups,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
- )
- else:
- self.upsamplers = None
-
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
- for resnet, attn in zip(self.resnets, self.attentions):
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states)
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
-
-
-class CrossAttnUpBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- cross_attention_dim=1280,
- output_scale_factor=1.0,
- add_upsample=True,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- if not dual_cross_attention:
- attentions.append(
- Transformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
- in_channels=out_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- )
- )
- else:
- attentions.append(
- DualTransformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
- in_channels=out_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- )
- )
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
- )
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states,
- res_hidden_states_tuple,
- temb=None,
- encoder_hidden_states=None,
- cross_attention_kwargs=None,
- upsample_size=None,
- attention_mask=None,
- ):
- # TODO(Patrick, William) - attention mask is not used
- for resnet, attn in zip(self.resnets, self.attentions):
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(attn, return_dict=False),
- hidden_states,
- encoder_hidden_states,
- cross_attention_kwargs,
- )[0]
- else:
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- ).sample
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size)
-
- return hidden_states
-
-
-class UpBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor=1.0,
- add_upsample=True,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
- )
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
- ):
- for resnet in self.resnets:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
- else:
- hidden_states = resnet(hidden_states, temb)
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size)
-
- return hidden_states
-
-
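
The down blocks accumulate one feature map per layer into `output_states`, and the up blocks above consume them in reverse: each iteration pops the most recent entry off `res_hidden_states_tuple` and concatenates it with the current activations along the channel dimension, which is why the resnets are built with `resnet_in_channels + res_skip_channels` input channels. A minimal, hypothetical sketch of that round trip with plain tensors (no real blocks involved):

```python
import torch

# Down path: every layer appends its output to a tuple of skip features.
down_block_res_samples = ()
hidden_states = torch.randn(1, 32, 16, 16)
for _ in range(3):
    hidden_states = hidden_states * 2.0  # stand-in for resnet/attn/downsample work
    down_block_res_samples += (hidden_states,)

# Up path: pop the most recent skip feature and concatenate on channels.
for _ in range(3):
    res_hidden_states = down_block_res_samples[-1]
    down_block_res_samples = down_block_res_samples[:-1]
    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    hidden_states = hidden_states[:, :32]  # stand-in for a resnet mapping back to 32 channels

print(hidden_states.shape)  # torch.Size([1, 32, 16, 16])
```
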
-class UpDecoderBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor=1.0,
- add_upsample=True,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- input_channels = in_channels if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=input_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
- )
- else:
- self.upsamplers = None
-
- def forward(self, hidden_states):
- for resnet in self.resnets:
- hidden_states = resnet(hidden_states, temb=None)
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
-
-
-class AttnUpDecoderBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=1.0,
- add_upsample=True,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- for i in range(num_layers):
- input_channels = in_channels if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=input_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- attentions.append(
- AttentionBlock(
- out_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- norm_num_groups=resnet_groups,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
- )
- else:
- self.upsamplers = None
-
- def forward(self, hidden_states):
- for resnet, attn in zip(self.resnets, self.attentions):
- hidden_states = resnet(hidden_states, temb=None)
- hidden_states = attn(hidden_states)
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
-
-
-class AttnSkipUpBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- output_scale_factor=np.sqrt(2.0),
- upsample_padding=1,
- add_upsample=True,
- ):
- super().__init__()
- self.attentions = nn.ModuleList([])
- self.resnets = nn.ModuleList([])
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- self.resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
-                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
- groups_out=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.attentions.append(
- AttentionBlock(
- out_channels,
- num_head_channels=attn_num_head_channels,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- )
- )
-
- self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
- if add_upsample:
- self.resnet_up = ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min(out_channels // 4, 32),
- groups_out=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- use_in_shortcut=True,
- up=True,
- kernel="fir",
- )
- self.skip_conv = nn.Conv2d(
- out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
- )
- self.skip_norm = pi.nn.GroupNorm(
- num_groups=min(out_channels // 4, 32),
- num_channels=out_channels,
- eps=resnet_eps,
- affine=True,
- )
- self.act = nn.SiLU()
- else:
- self.resnet_up = None
- self.skip_conv = None
- self.skip_norm = None
- self.act = None
-
- def forward(
- self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None
- ):
- for resnet in self.resnets:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- hidden_states = resnet(hidden_states, temb)
-
- hidden_states = self.attentions[0](hidden_states)
-
- if skip_sample is not None:
- skip_sample = self.upsampler(skip_sample)
- else:
- skip_sample = 0
-
- if self.resnet_up is not None:
- skip_sample_states = self.skip_norm(hidden_states)
- skip_sample_states = self.act(skip_sample_states)
- skip_sample_states = self.skip_conv(skip_sample_states)
-
- skip_sample = skip_sample + skip_sample_states
-
- hidden_states = self.resnet_up(hidden_states, temb)
-
- return hidden_states, skip_sample
-
-
-class SkipUpBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_pre_norm: bool = True,
- output_scale_factor=np.sqrt(2.0),
- add_upsample=True,
- upsample_padding=1,
- ):
- super().__init__()
- self.resnets = nn.ModuleList([])
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- self.resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
- groups_out=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
- if add_upsample:
- self.resnet_up = ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=min(out_channels // 4, 32),
- groups_out=min(out_channels // 4, 32),
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- use_in_shortcut=True,
- up=True,
- kernel="fir",
- )
- self.skip_conv = nn.Conv2d(
- out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
- )
- self.skip_norm = pi.nn.GroupNorm(
- num_groups=min(out_channels // 4, 32),
- num_channels=out_channels,
- eps=resnet_eps,
- affine=True,
- )
- self.act = nn.SiLU()
- else:
- self.resnet_up = None
- self.skip_conv = None
- self.skip_norm = None
- self.act = None
-
- def forward(
- self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None
- ):
- for resnet in self.resnets:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- hidden_states = resnet(hidden_states, temb)
-
- if skip_sample is not None:
- skip_sample = self.upsampler(skip_sample)
- else:
- skip_sample = 0
-
- if self.resnet_up is not None:
- skip_sample_states = self.skip_norm(hidden_states)
- skip_sample_states = self.act(skip_sample_states)
- skip_sample_states = self.skip_conv(skip_sample_states)
-
- skip_sample = skip_sample + skip_sample_states
-
- hidden_states = self.resnet_up(hidden_states, temb)
-
- return hidden_states, skip_sample
-
-
-class ResnetUpsampleBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor=1.0,
- add_upsample=True,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [
- ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- up=True,
- )
- ]
- )
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
- ):
- for resnet in self.resnets:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = pi.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
- else:
- hidden_states = resnet(hidden_states, temb)
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, temb)
-
- return hidden_states
-
-
-class SimpleCrossAttnUpBlock2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
- cross_attention_dim=1280,
- output_scale_factor=1.0,
- add_upsample=True,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
-
- self.num_heads = out_channels // self.attn_num_head_channels
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- attentions.append(
- CrossAttention(
- query_dim=out_channels,
- cross_attention_dim=out_channels,
- heads=self.num_heads,
- dim_head=attn_num_head_channels,
- added_kv_proj_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- bias=True,
- upcast_softmax=True,
- processor=CrossAttnAddedKVProcessor(),
- )
- )
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList(
- [
- ResnetBlock2D(
- in_channels=out_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- up=True,
- )
- ]
- )
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states,
- res_hidden_states_tuple,
- temb=None,
- encoder_hidden_states=None,
- upsample_size=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- ):
- cross_attention_kwargs = (
- cross_attention_kwargs if cross_attention_kwargs is not None else {}
- )
- for resnet, attn in zip(self.resnets, self.attentions):
- # resnet
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1)
-
- hidden_states = resnet(hidden_states, temb)
-
- # attn
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, temb)
-
- return hidden_states
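
The `Simple*CrossAttn*` blocks above all use the same calling convention: `cross_attention_kwargs` defaults to `None`, is swapped for an empty dict at call time, and is then splatted into the attention call, so no mutable default ever leaks between calls. A tiny sketch of that convention; the `attn` stand-in below is hypothetical, not the real `CrossAttention` module:

```python
def call_attention(attn, hidden_states, encoder_hidden_states=None,
                   attention_mask=None, cross_attention_kwargs=None):
    # Default to an empty dict at call time rather than in the signature,
    # then forward any extra processor arguments with **.
    cross_attention_kwargs = (
        cross_attention_kwargs if cross_attention_kwargs is not None else {}
    )
    return attn(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )


# Hypothetical stand-in for a CrossAttention module.
attn = lambda h, encoder_hidden_states=None, attention_mask=None, scale=1.0: h * scale
print(call_attention(attn, 2.0, cross_attention_kwargs={"scale": 0.5}))  # 1.0
```
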
diff --git a/pi/models/unet/unet_2d_condition.py b/pi/models/unet/unet_2d_condition.py
deleted file mode 100644
index e5df4d1..0000000
--- a/pi/models/unet/unet_2d_condition.py
+++ /dev/null
@@ -1,517 +0,0 @@
-import logging
-import math
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union, Dict, List, Any
-
-from diffusers.models.unet_2d_blocks import get_down_block
-
-from .cross_attention import AttnProcessor
-from .embeddings import Timesteps, TimestepEmbedding
-from .unet_2d_blocks import (
- UNetMidBlock2DCrossAttn,
- UNetMidBlock2DSimpleCrossAttn,
- get_up_block,
- CrossAttnDownBlock2D,
- DownBlock2D,
- CrossAttnUpBlock2D,
- UpBlock2D,
-)
-from ... import nn, Tensor
-from ... import pi
-
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class UNet2DConditionOutput:
- """
- Args:
-        sample (`pi.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Hidden states conditioned on `encoder_hidden_states` input. Output of the last layer of the model.
- """
-
- sample: pi.FloatTensor
-
-
-class UNet2DConditionModel:
- def __init__(
- self,
- sample_size: Optional[int] = None,
- in_channels: int = 4,
- out_channels: int = 4,
- center_input_sample: bool = False,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str] = (
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "DownBlock2D",
- ),
- mid_block_type: str = "UNetMidBlock2DCrossAttn",
- up_block_types: Tuple[str] = (
- "UpBlock2D",
- "CrossAttnUpBlock2D",
- "CrossAttnUpBlock2D",
- "CrossAttnUpBlock2D",
- ),
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: int = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: int = 1280,
- attention_head_dim: Union[int, Tuple[int]] = 8,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- class_embed_type: Optional[str] = None,
- num_class_embeds: Optional[int] = None,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- ):
- super().__init__()
-
- self.sample_size = sample_size
- time_embed_dim = block_out_channels[0] * 4
-
- # input
- self.conv_in = nn.Conv2d(
- in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
- )
-
- # time
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
- timestep_input_dim = block_out_channels[0]
-
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
-
- # class embedding
- if class_embed_type is None and num_class_embeds is not None:
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
- elif class_embed_type == "timestep":
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
- elif class_embed_type == "identity":
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
- else:
- self.class_embedding = None
-
- self.down_blocks = nn.ModuleList([])
- self.mid_block = None
- self.up_blocks = nn.ModuleList([])
-
- if isinstance(only_cross_attention, bool):
- only_cross_attention = [only_cross_attention] * len(down_block_types)
-
- if isinstance(attention_head_dim, int):
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
- # down
- output_channel = block_out_channels[0]
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- down_block = get_down_block(
- down_block_type,
- num_layers=layers_per_block,
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- add_downsample=not is_final_block,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attention_head_dim[i],
- downsample_padding=downsample_padding,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- self.down_blocks.append(down_block)
-
- # mid
- if mid_block_type == "UNetMidBlock2DCrossAttn":
- self.mid_block = UNetMidBlock2DCrossAttn(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attention_head_dim[-1],
- resnet_groups=norm_num_groups,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attention_head_dim[-1],
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- else:
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
-
- # count how many layers upsample the images
- self.num_upsamplers = 0
-
- # up
- reversed_block_out_channels = list(reversed(block_out_channels))
- reversed_attention_head_dim = list(reversed(attention_head_dim))
- only_cross_attention = list(reversed(only_cross_attention))
- output_channel = reversed_block_out_channels[0]
- for i, up_block_type in enumerate(up_block_types):
- is_final_block = i == len(block_out_channels) - 1
-
- prev_output_channel = output_channel
- output_channel = reversed_block_out_channels[i]
- input_channel = reversed_block_out_channels[
- min(i + 1, len(block_out_channels) - 1)
- ]
-
-            # add upsample block for all but the final layer
- if not is_final_block:
- add_upsample = True
- self.num_upsamplers += 1
- else:
- add_upsample = False
-
- up_block = get_up_block(
- up_block_type,
- num_layers=layers_per_block + 1,
- in_channels=input_channel,
- out_channels=output_channel,
- prev_output_channel=prev_output_channel,
- temb_channels=time_embed_dim,
- add_upsample=add_upsample,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=reversed_attention_head_dim[i],
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- self.up_blocks.append(up_block)
- prev_output_channel = output_channel
-
- # out
- self.conv_norm_out = nn.GroupNorm(
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
- )
- self.conv_act = nn.SiLU()
- self.conv_out = nn.Conv2d(
- block_out_channels[0], out_channels, kernel_size=3, padding=1
- )
-
- @property
- def attn_processors(self) -> Dict[str, AttnProcessor]:
- r"""
- Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model,
-            indexed by their weight names.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(
- name: str, module: pi.nn.Module, processors: Dict[str, AttnProcessor]
- ):
- if hasattr(module, "set_processor"):
- processors[f"{name}.processor"] = module.processor
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
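
The `attn_processors` property above builds its dictionary by recursing over `named_children()` and recording every module that exposes `set_processor` under a dotted name ending in `.processor`; `set_attn_processor` below walks the same tree to install new processors. A hedged sketch of that traversal with hypothetical dummy modules:

```python
from torch import nn


class DummyAttention(nn.Module):
    """Hypothetical module exposing the same set_processor hook."""

    def __init__(self):
        super().__init__()
        self.processor = "default-processor"

    def set_processor(self, processor):
        self.processor = processor


def collect_processors(root: nn.Module) -> dict:
    processors = {}

    def recurse(name, module):
        if hasattr(module, "set_processor"):
            processors[f"{name}.processor"] = module.processor
        for sub_name, child in module.named_children():
            recurse(f"{name}.{sub_name}", child)

    for name, module in root.named_children():
        recurse(name, module)
    return processors


model = nn.Sequential(nn.Sequential(DummyAttention()), DummyAttention())
print(collect_processors(model))
# {'0.0.processor': 'default-processor', '1.processor': 'default-processor'}
```
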
- def set_attn_processor(
- self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]
- ):
- r"""
- Parameters:
-            processor (`dict` of `AttnProcessor` or `AttnProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the
-                processor of **all** `CrossAttention` layers.
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: pi.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- def set_attention_slice(self, slice_size):
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- sliceable_head_dims = []
-
- def fn_recursive_retrieve_slicable_dims(module: pi.nn.Module):
- if hasattr(module, "set_attention_slice"):
- sliceable_head_dims.append(module.sliceable_head_dim)
-
- for child in module.children():
- fn_recursive_retrieve_slicable_dims(child)
-
- # retrieve number of attention layers
- for module in self.children():
- fn_recursive_retrieve_slicable_dims(module)
-
- num_slicable_layers = len(sliceable_head_dims)
-
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = [dim // 2 for dim in sliceable_head_dims]
- elif slice_size == "max":
- # make smallest slice possible
- slice_size = num_slicable_layers * [1]
-
- slice_size = (
- num_slicable_layers * [slice_size]
- if not isinstance(slice_size, list)
- else slice_size
- )
-
- if len(slice_size) != len(sliceable_head_dims):
- raise ValueError(
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
- )
-
- for i in range(len(slice_size)):
- size = slice_size[i]
- dim = sliceable_head_dims[i]
- if size is not None and size > dim:
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
-        # Recursively walk through all the children.
-        # Any child that exposes the set_attention_slice method
-        # gets the message.
- def fn_recursive_set_attention_slice(
- module: pi.nn.Module, slice_size: List[int]
- ):
- if hasattr(module, "set_attention_slice"):
- module.set_attention_slice(slice_size.pop())
-
- for child in module.children():
- fn_recursive_set_attention_slice(child, slice_size)
-
- reversed_slice_size = list(reversed(slice_size))
- for module in self.children():
- fn_recursive_set_attention_slice(module, reversed_slice_size)
-
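
Before recursing into the module tree, `set_attention_slice` above normalizes the user-facing `slice_size` argument (`"auto"`, `"max"`, a single int, or a list) into one slice size per sliceable layer. A self-contained sketch of just that normalization step, with hypothetical head dimensions:

```python
def normalize_slice_size(slice_size, sliceable_head_dims):
    if slice_size == "auto":
        # Halve every head dimension: a common speed/memory trade-off.
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # Smallest possible slices: one at a time.
        slice_size = len(sliceable_head_dims) * [1]
    if not isinstance(slice_size, list):
        slice_size = len(sliceable_head_dims) * [slice_size]
    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError("need one slice size per sliceable layer")
    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
    return slice_size


dims = [8, 16, 32]  # hypothetical sliceable head dims
print(normalize_slice_size("auto", dims))  # [4, 8, 16]
print(normalize_slice_size("max", dims))   # [1, 1, 1]
print(normalize_slice_size(4, dims))       # [4, 4, 4]
```
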
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(
- module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
- ):
- module.gradient_checkpointing = value
-
- def forward(
- self,
- sample: pi.FloatTensor,
- timestep: Union[pi.Tensor, float, int],
- encoder_hidden_states: pi.Tensor,
- class_labels: Optional[pi.Tensor] = None,
- attention_mask: Optional[pi.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[UNet2DConditionOutput, Tuple]:
- r"""
- Args:
- sample (`pi.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
- timestep (`pi.FloatTensor` or `float` or `int`): (batch) timesteps
- encoder_hidden_states (`pi.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
- returning a tuple, the first element is the sample tensor.
- """
- # By default samples have to be AT least a multiple of the overall upsampling factor.
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
- # on the fly if necessary.
- default_overall_up_factor = 2 ** self.num_upsamplers
-
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
- forward_upsample_size = False
- upsample_size = None
-
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
- logger.info("Forward upsample size to force interpolation output size.")
- forward_upsample_size = True
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 0. center input if necessary
- if self.config.center_input_sample:
- sample = 2 * sample - 1.0
-
- # 1. time
- timesteps = timestep
- if not pi.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = pi.float32 if is_mps else pi.float64
- else:
- dtype = pi.int32 if is_mps else pi.int64
- timesteps = pi.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
- emb = self.time_embedding(t_emb)
-
- if self.class_embedding is not None:
- if class_labels is None:
- raise ValueError(
- "class_labels should be provided when num_class_embeds > 0"
- )
-
- if self.config.class_embed_type == "timestep":
- class_labels = self.time_proj(class_labels)
-
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
- emb = emb + class_emb
-
- # 2. pre-process
- sample = self.conv_in(sample)
-
- # 3. down
- down_block_res_samples = (sample,)
- for downsample_block in self.down_blocks:
- if (
- hasattr(downsample_block, "has_cross_attention")
- and downsample_block.has_cross_attention
- ):
- sample, res_samples = downsample_block(
- hidden_states=sample,
- temb=emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
-
- down_block_res_samples += res_samples
-
- # 4. mid
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
-
- # 5. up
- for i, upsample_block in enumerate(self.up_blocks):
- is_final_block = i == len(self.up_blocks) - 1
-
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
- down_block_res_samples = down_block_res_samples[
- : -len(upsample_block.resnets)
- ]
-
- # if we have not reached the final block and need to forward the
- # upsample size, we do it here
- if not is_final_block and forward_upsample_size:
- upsample_size = down_block_res_samples[-1].shape[2:]
-
- if (
- hasattr(upsample_block, "has_cross_attention")
- and upsample_block.has_cross_attention
- ):
- sample = upsample_block(
- hidden_states=sample,
- temb=emb,
- res_hidden_states_tuple=res_samples,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- upsample_size=upsample_size,
- attention_mask=attention_mask,
- )
- else:
- sample = upsample_block(
- hidden_states=sample,
- temb=emb,
- res_hidden_states_tuple=res_samples,
- upsample_size=upsample_size,
- )
- # 6. post-process
- sample = self.conv_norm_out(sample)
- sample = self.conv_act(sample)
- sample = self.conv_out(sample)
-
- if not return_dict:
- return (sample,)
-
- return UNet2DConditionOutput(sample=sample)
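As an aside for readers of the removed UNet code above: the sliced-attention configuration that the deleted `set_attention_slice` implements reduces to a small resolution rule over `slice_size`. The sketch below is editorial only and ships with nothing in this patch (the helper name `resolve_slice_size` is hypothetical); it mirrors the `"auto"` / `"max"` / int / list handling shown in the deleted method.

```python
# Editorial sketch of how the deleted set_attention_slice resolves `slice_size`
# against the sliceable head dims it collects from submodules.
from typing import List, Union


def resolve_slice_size(
    slice_size: Union[str, int, List[int]], sliceable_head_dims: List[int]
) -> List[int]:
    num_layers = len(sliceable_head_dims)
    if slice_size == "auto":
        # halve each head dim: attention runs in two steps per layer
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # one slice at a time: slowest but most memory-frugal setting
        slice_size = num_layers * [1]
    if not isinstance(slice_size, list):
        # a single int is broadcast to every sliceable layer
        slice_size = num_layers * [slice_size]
    if len(slice_size) != num_layers:
        raise ValueError(
            f"got {len(slice_size)} slice sizes for {num_layers} attention layers"
        )
    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be <= {dim}")
    return slice_size


assert resolve_slice_size("auto", [8, 8]) == [4, 4]
assert resolve_slice_size("max", [8, 8]) == [1, 1]
assert resolve_slice_size(2, [8, 8]) == [2, 2]
```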
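Similarly, the deleted `forward` normalizes whatever it receives as `timestep` (Python scalar, 0-d tensor, or batched tensor) into a 1-D tensor broadcast over the batch before computing the time embedding. A minimal sketch of that step, written against plain `torch` rather than the `pi` shim and omitting the MPS-specific dtype choice, might look like this:

```python
# Editorial sketch (not part of this patch) of the timestep normalization
# performed at the top of the deleted forward() above.
import torch


def normalize_timesteps(timestep, sample: torch.Tensor) -> torch.Tensor:
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # Python float/int -> 1-element tensor on the sample's device
        dtype = torch.float64 if isinstance(timestep, float) else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif timesteps.ndim == 0:
        # 0-d tensor -> 1-element tensor
        timesteps = timesteps[None].to(sample.device)
    # broadcast to the batch dimension (ONNX/Core ML friendly)
    return timesteps.expand(sample.shape[0])


sample = torch.randn(4, 4, 32, 32)
print(normalize_timesteps(10, sample))                      # tensor([10, 10, 10, 10])
print(normalize_timesteps(torch.tensor(10), sample).shape)  # torch.Size([4])
```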
diff --git a/pyproject.toml b/pyproject.toml
index 37876fc..314893f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,12 @@ requires = ["setuptools>=42",
"ninja"]
build-backend = "setuptools.build_meta"
+[tool.pytest.ini_options]
+log_cli = false
+log_cli_level = "DEBUG"
+log_cli_format = "[%(filename)s:%(funcName)s:%(lineno)d] %(message)s"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
+
[tool.cibuildwheel]
before-build = "pip install -r build-requirements.txt -r requirements.txt -v"
# HOLY FUCK
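Note on the `[tool.pytest.ini_options]` block added above: with `log_cli = false`, live logging stays off by default while the level and format are pre-configured. If live logs are wanted for a particular run, pytest's standard overrides should do it, e.g. `pytest -o log_cli=true --log-cli-level=DEBUG` (stock pytest options, not something this patch adds).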
diff --git a/requirements.txt b/requirements.txt
index 8ec3a01..a9b60f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ typeguard==3.0.0b2
multiprocess
pybind11
pytest
+pyccolo
\ No newline at end of file