Commit 3090c6b

feat: add apex unet example

flavioschneider committed Dec 28, 2022
1 parent a6de0f8 commit 3090c6b
Showing 5 changed files with 93 additions and 90 deletions.
137 changes: 70 additions & 67 deletions README.md
@@ -69,89 +69,92 @@

```py
x = torch.randn(1, 8, 16, 16)
y = unet(x) # [1, 8, 16, 16]
```

(This hunk replaces the previous "### Attention UNet" example with the new "### ApeX UNet" section below.)

### ApeX UNet

<details> <summary> (Code): ApeX is a UNet template complete with tools for easy customizability. The following example UNet includes: (1) custom item arrangement for resnets, modulation, attention, and cross attention, (2) a custom skip connection with concatenation, (3) time conditioning (usually used for diffusion), and (4) classifier-free guidance. </summary>

```py
from typing import Callable, Optional, Sequence

from a_unet import TimeConditioningPlugin, ClassifierFreeGuidancePlugin
from a_unet.apex import (
    XUNet,
    XBlock,
    ResnetItem as R,
    AttentionItem as A,
    CrossAttentionItem as C,
    FeedForwardItem as F,
    ModulationItem as M,
    SkipCat
)

def UNet(
    dim: int,
    in_channels: int,
    channels: Sequence[int],
    factors: Sequence[int],
    items: Sequence[int],
    attentions: Sequence[int],
    cross_attentions: Sequence[int],
    attention_features: int,
    attention_heads: int,
    embedding_features: Optional[int] = None,
    skip_t: Callable = SkipCat,
    resnet_groups: int = 8,
    modulation_features: int = 1024,
    embedding_max_length: int = 0,
    use_classifier_free_guidance: bool = False,
    out_channels: Optional[int] = None,
):
    # Check that all per-layer sequences have matching lengths
    num_layers = len(channels)
    sequences = (channels, factors, items, attentions, cross_attentions)
    assert all(len(sequence) == num_layers for sequence in sequences)

    # Define UNet type with time conditioning and (optional) classifier-free guidance plugins
    UNet = TimeConditioningPlugin(XUNet)
    if use_classifier_free_guidance:
        UNet = ClassifierFreeGuidancePlugin(UNet, embedding_max_length)

    return UNet(
        dim=dim,
        in_channels=in_channels,
        out_channels=out_channels,
        blocks=[
            XBlock(
                channels=channels,
                factor=factor,
                items=([R, M] + [A] * n_att + [C] * n_cross) * n_items,
            ) for channels, factor, n_items, n_att, n_cross in zip(*sequences)
        ],
        skip_t=skip_t,
        attention_features=attention_features,
        attention_heads=attention_heads,
        embedding_features=embedding_features,
        modulation_features=modulation_features,
        resnet_groups=resnet_groups
    )
```

</details>

```py
unet = UNet(
    dim=2,
    in_channels=2,
    channels=[128, 256, 512, 1024],
    factors=[2, 2, 2, 2],
    items=[2, 2, 2, 2],
    attentions=[0, 0, 0, 1],
    cross_attentions=[1, 1, 1, 1],
    attention_features=64,
    attention_heads=8,
    embedding_features=768,
    use_classifier_free_guidance=False
)
x = torch.randn(2, 2, 64, 64)
time = [0.2, 0.5]
embedding = torch.randn(2, 512, 768)
y = unet(x, time=time, embedding=embedding) # [2, 2, 64, 64]
```
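Note: the example above builds the UNet with `use_classifier_free_guidance=False`. The sketch below is an illustrative assumption, not part of this commit: it reuses only the constructor arguments already listed above (`use_classifier_free_guidance`, `embedding_max_length`), and any guidance-specific forward arguments exposed by `ClassifierFreeGuidancePlugin` (such as an embedding scale) are intentionally not shown.

```py
import torch

# Sketch (assumed usage): enable classifier-free guidance via the same factory.
# embedding_max_length=64 is an arbitrary example value; it is assumed here to
# match the length of the conditioning embedding passed at call time.
cfg_unet = UNet(
    dim=2,
    in_channels=2,
    channels=[128, 256, 512, 1024],
    factors=[2, 2, 2, 2],
    items=[2, 2, 2, 2],
    attentions=[0, 0, 0, 1],
    cross_attentions=[1, 1, 1, 1],
    attention_features=64,
    attention_heads=8,
    embedding_features=768,
    embedding_max_length=64,
    use_classifier_free_guidance=True,
)

x = torch.randn(2, 2, 64, 64)
time = [0.2, 0.5]
embedding = torch.randn(2, 64, 768)  # length chosen to match embedding_max_length
y = cfg_unet(x, time=time, embedding=embedding)  # [2, 2, 64, 64]
```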
26 changes: 13 additions & 13 deletions a_unet/unet/apex.py → a_unet/apex.py
```diff
@@ -2,7 +2,7 @@

 from torch import Tensor, nn

-from ..blocks import (
+from .blocks import (
     Attention,
     Conv,
     ConvBlock,
@@ -28,7 +28,7 @@
 Items
 """

-# Selections for item forward paramters
+# Selections for item forward parameters
 SelectX = Select(lambda x, *_: (x,))
 SelectXE = Select(lambda x, f, e, *_: (x, e))
 SelectXF = Select(lambda x, f, *_: (x, f))
@@ -204,13 +204,13 @@ def FeedForwardItem(
 """ Skip Adapters """


-def SkipAdapterItem(
+def SkipAdapter(
     dim: Optional[int] = None,
     in_channels: Optional[int] = None,
     out_channels: Optional[int] = None,
     **kwargs,
 ):
-    msg = "SkipAdapterItem requires dim, in_channels, out_channels"
+    msg = "SkipAdapter requires dim, in_channels, out_channels"
     assert exists(dim) and exists(in_channels) and exists(out_channels), msg
     Item = SelectX(Conv)
     return (
@@ -228,25 +228,25 @@ def SkipAdapterItem(
 """ Skip Connections """


-def SkipAddItem(**kwargs) -> nn.Module:
+def SkipAdd(**kwargs) -> nn.Module:
     return MergeAdd()


-def SkipCatItem(
+def SkipCat(
     dim: Optional[int] = None, out_channels: Optional[int] = None, **kwargs
 ) -> nn.Module:
-    msg = "SkipCatItem requires dim, out_channels"
+    msg = "SkipCat requires dim, out_channels"
     assert exists(dim) and exists(out_channels), msg
     return MergeCat(dim=dim, channels=out_channels)


-def SkipModulateItem(
+def SkipModulate(
     dim: Optional[int] = None,
     out_channels: Optional[int] = None,
     modulation_features: Optional[int] = None,
     **kwargs,
 ) -> nn.Module:
-    msg = "SkipModulateItem requires dim, out_channels, modulation_features"
+    msg = "SkipModulate requires dim, out_channels, modulation_features"
     assert exists(dim) and exists(out_channels) and exists(modulation_features), msg
     return MergeModulate(
         dim=dim, channels=out_channels, modulation_features=modulation_features
@@ -262,8 +262,8 @@ def __init__(
         in_channels: int,
         downsample_t: Callable = DownsampleItem,
         upsample_t: Callable = UpsampleItem,
-        skip_t: Callable = SkipAddItem,
-        skip_adapter_t: Callable = SkipAdapterItem,
+        skip_t: Callable = SkipAdd,
+        skip_adapter_t: Callable = SkipAdapter,
         items: Sequence[Callable] = [],
         items_up: Optional[Sequence[Callable]] = None,
         out_channels: Optional[int] = None,
@@ -304,13 +304,13 @@ def forward(


 # Block type, to be provided in UNet
-BlockT = T(Block, override=False)
+XBlock = T(Block, override=False)


 """ UNet """


-class UNet(nn.Module):
+class XUNet(nn.Module):
     def __init__(
         self,
         in_channels: int,
```
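For orientation, the renames above (`UNet` → `XUNet`, `BlockT` → `XBlock`, `Skip*Item` → `Skip*`) match the names used in the new README example. The following is a minimal, assumed sketch of the renamed types used directly, here with additive skips (`SkipAdd`) instead of concatenation; the keyword arguments accepted by `XUNet` are inferred from the README example, not from this diff.

```py
from a_unet.apex import XUNet, XBlock, ResnetItem, AttentionItem, SkipAdd

# Assumed sketch: a small 2D XUNet built directly from the renamed types,
# using additive skip connections (SkipAdd) rather than concatenation (SkipCat).
unet = XUNet(
    dim=2,
    in_channels=8,
    blocks=[
        XBlock(channels=64, factor=2, items=[ResnetItem, ResnetItem]),
        XBlock(channels=128, factor=2, items=[ResnetItem, AttentionItem]),
    ],
    skip_t=SkipAdd,
    attention_features=64,
    attention_heads=8,
)
```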
18 changes: 9 additions & 9 deletions a_unet/blocks.py
```diff
@@ -344,15 +344,6 @@ def forward(x: Tensor, y: Tensor, features: Tensor, *args) -> Tensor:
     return Module([to_scale], forward)


-def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
-    if proba == 1:
-        return torch.ones(shape, device=device, dtype=torch.bool)
-    elif proba == 0:
-        return torch.zeros(shape, device=device, dtype=torch.bool)
-    else:
-        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
-
-
 """
 Embedders
 """
@@ -419,6 +410,15 @@ def forward(self, texts: Sequence[str]) -> Tensor:
 """


+def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
+    if proba == 1:
+        return torch.ones(shape, device=device, dtype=torch.bool)
+    elif proba == 0:
+        return torch.zeros(shape, device=device, dtype=torch.bool)
+    else:
+        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
+
+
 def ClassifierFreeGuidancePlugin(
     net_t: Type[nn.Module],
     embedding_max_length: int,
```
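The `rand_bool` helper is moved, unchanged, from the embedders area down next to the plugin code in `blocks.py`. A minimal sketch of what it does follows; its use for dropping conditioning during classifier-free guidance training is an assumption, not something shown in this diff.

```py
import torch
from a_unet.blocks import rand_bool

# Draw a boolean mask where each entry is True with probability `proba`,
# e.g. to pick which batch items have their conditioning embedding masked out.
torch.manual_seed(0)
mask = rand_bool(shape=(4, 1, 1), proba=0.1)
print(mask.dtype, tuple(mask.shape))  # torch.bool (4, 1, 1)
```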
Empty file removed: a_unet/unet/__init__.py
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
     name="a-unet",
     packages=find_packages(exclude=[]),
-    version="0.0.7",
+    version="0.0.8",
     license="MIT",
     description="A-UNet",
     long_description_content_type="text/markdown",
```
