
Commit

Changed UNet.py
Changed layer order from conv -> norm -> act to norm -> act -> conv
Fixed incorrect activations in residual block bias and residual connection
Removed ConvBlock
abarankab committed Sep 21, 2021
1 parent 579e3b2 commit 6bf4897
Showing 1 changed file with 58 additions and 77 deletions.
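
For context, the reordering described in the commit message moves each block from a post-activation layout (conv -> norm -> act) to a pre-activation layout (norm -> act -> conv). A minimal sketch of the two orderings, assuming GroupNorm and ReLU with illustrative channel sizes not taken from this commit:

import torch
import torch.nn as nn

# post-activation ordering: conv -> norm -> act
post_act = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1),
    nn.GroupNorm(32, 64),
    nn.ReLU(),
)

# pre-activation ordering: norm -> act -> conv
pre_act = nn.Sequential(
    nn.GroupNorm(32, 64),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1),
)

x = torch.randn(1, 64, 32, 32)
assert post_act(x).shape == pre_act(x).shape == x.shape  # the ordering changes, the shapes do not
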
135 changes: 58 additions & 77 deletions ddpm/unet.py
@@ -6,6 +6,17 @@
from einops import rearrange


def get_norm(norm, num_features, num_groups):
if norm == "in":
return nn.InstanceNorm2d(num_features, affine=True)
elif norm == "bn":
return nn.BatchNorm2d(num_features)
elif norm == "gn":
return nn.GroupNorm(num_groups, num_features)
elif norm is None:
return nn.Identity()
else:
raise ValueError("unknown normalization type")
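
# Usage sketch for get_norm (illustrative; the channel and group counts are assumptions):
#   get_norm("gn", 128, 32)   # -> nn.GroupNorm(32, 128)
#   get_norm("in", 128, 32)   # -> nn.InstanceNorm2d(128, affine=True)
#   get_norm("bn", 128, 32)   # -> nn.BatchNorm2d(128)
#   get_norm(None, 128, 32)   # -> nn.Identity()
# Any other value raises ValueError("unknown normalization type").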


class PositionalEmbedding(nn.Module):
__doc__ = r"""Computes a positional embedding of timesteps.
@@ -79,62 +90,14 @@ def __init__(self, in_channels, align_corners=True, use_reflection_pad=False):
super().__init__()

self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=align_corners),
nn.Upsample(scale_factor=2, align_corners=align_corners),
nn.Conv2d(in_channels, in_channels, 3, padding=1, padding_mode="zeros" if not use_reflection_pad else "reflect"),
)

def forward(self, x):
return self.upsample(x)


class ConvBlock(nn.Module):
__doc__ = r"""Applies 2d convolution, normalization and activation to a tensor.
Input:
x: tensor of shape (N, in_channels, H, W)
Output:
tensor of shape (N, out_channels, H, W)
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
activation (function): activation function. Default: torch.nn.functional.relu
norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
num_groups (int): number of groups used in group normalization. Default: 8
"""

def __init__(
self,
in_channels,
out_channels,
activation=F.relu,
norm="in",
num_groups=8,
):
super().__init__()

self.activation = activation

modules = []
modules.append(nn.Conv2d(
in_channels, out_channels, 3, padding=1,
bias=False if norm is not None else True,
))

if norm == "in":
modules.append(nn.InstanceNorm2d(out_channels, affine=True))
elif norm == "gn":
modules.append(nn.GroupNorm(num_groups, out_channels))
elif norm == "bn":
modules.append(nn.BatchNorm2d(out_channels))
elif norm is not None:
raise ValueError("__init__() got unknown normalization type")

self.block = nn.Sequential(*modules)

def forward(self, x):
return self.activation(self.block(x))


class ResidualBlock(nn.Module):
__doc__ = r"""Applies two conv blocks with resudual connection. Adds time and class conditioning by adding bias after first convolution.
@@ -150,8 +113,8 @@ class ResidualBlock(nn.Module):
time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
activation (function): activation function. Default: torch.nn.functional.relu
norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
num_groups (int): number of groups used in group normalization. Default: 8
norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
num_groups (int): number of groups used in group normalization. Default: 32
"""

def __init__(
@@ -161,42 +124,51 @@ def __init__(
dropout,
time_emb_dim=None,
num_classes=None,
**kwargs,
activation=F.relu,
norm="gn",
num_groups=32,
):
super().__init__()

self.conv_block_1 = ConvBlock(in_channels, out_channels, **kwargs)
self.conv_block_2 = nn.Sequential(
self.activation = activation

self.norm_1 = get_norm(norm, in_channels, num_groups)
self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

self.conv_2 = nn.Sequential(
nn.Dropout(p=dropout),
ConvBlock(out_channels, out_channels, **kwargs),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
)
self.norm_2 = get_norm(norm, out_channels, num_groups)

self.time_bias = nn.Sequential(nn.ReLU(), nn.Linear(time_emb_dim, out_channels)) if time_emb_dim is not None else None
self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

def forward(self, x, time_emb=None, y=None):
out = self.conv_block_1(x)
out = self.activation(self.norm_1(x))
out = self.conv_1(out)

if self.time_bias is not None:
if time_emb is None:
raise ValueError("time conditioning was specified but time_emb is not passed")
out += self.time_bias(time_emb)[:, :, None, None]
out += self.time_bias(self.activation(time_emb))[:, :, None, None]

if self.class_bias is not None:
if y is None:
raise ValueError("class conditioning was specified but y is not passed")

out += self.class_bias(y)[:, :, None, None]

out = self.conv_block_2(out)

out = self.activation(self.norm_2(out))
out = self.conv_2(out)

return out + self.residual_connection(x)
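
# Usage sketch (illustrative; channel sizes and embedding dimension are assumptions):
#   block = ResidualBlock(64, 128, dropout=0.1, time_emb_dim=512, norm="gn", num_groups=32)
#   x = torch.randn(4, 64, 32, 32)
#   t_emb = torch.randn(4, 512)
#   block(x, time_emb=t_emb).shape   # -> torch.Size([4, 128, 32, 32])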


class AttentionBlock(nn.Module):
__doc__ = r"""Applies linear attention with a residual connection.
__doc__ = r"""Applies attention with a residual connection. This part differs a lot from what was used in the paper.
Input:
x: tensor of shape (N, in_channels, H, W)
@@ -211,7 +183,7 @@ def __init__(self, in_channels, heads=4, head_channels=32):
super().__init__()
self.heads = heads

self.norm = nn.InstanceNorm2d(in_channels, affine=True)
self.norm = nn.InstanceNorm2d(in_channels, affine=True) # GroupNorm in paper

mid_channels = head_channels * heads
self.to_qkv = nn.Conv2d(in_channels, mid_channels * 3, 1, bias=False)
@@ -246,9 +218,9 @@ class UNet(nn.Module):
num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
activation (function): activation function. Default: torch.nn.functional.relu
dropout (float): dropout rate at the end of each residual block
use_attn (bool): if True linear attention is used in residual blocks. Default: True
norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
num_groups (int): number of groups used in group normalization. Default: 8
attention_resolutions (tuple): list of relative resolutions at which to apply attention. Default: ()
norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
num_groups (int): number of groups used in group normalization. Default: 32
align_corners (bool): align_corners in bilinear upsampling. Default: True
use_reflection_pad (bool): if True reflection pad is used, otherwise zero pad is used. Default: False
initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0
@@ -264,15 +236,16 @@ def __init__(
num_classes=None,
activation=F.relu,
dropout=0.1,
use_attn=True,
norm="in",
num_groups=8,
attention_resolutions=(),
norm="gn",
num_groups=32,
align_corners=True,
use_reflection_pad=False,
initial_pad=0,
):
super().__init__()

self.activation = activation
self.initial_pad = initial_pad

self.num_classes = num_classes
@@ -283,6 +256,8 @@ def __init__(
nn.Linear(time_emb_dim * 4, time_emb_dim),
) if time_emb_dim is not None else None

self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

channels = (base_channels, *[base_channels * mult for mult in channel_mults])
channel_pairs = tuple(zip(channels[:-1], channels[1:]))

@@ -291,39 +266,40 @@

for ind, (in_channels, out_channels) in enumerate(channel_pairs):
is_last = (ind == len(channel_pairs) - 1)
relative_resolution = ind

self.downs.append(nn.ModuleList([
ResidualBlock(in_channels, out_channels, dropout,
time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
ResidualBlock(out_channels, out_channels, dropout,
time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
AttentionBlock(out_channels) if use_attn else nn.Identity(),
AttentionBlock(out_channels) if relative_resolution in attention_resolutions else nn.Identity(),
Downsample(out_channels, use_reflection_pad=use_reflection_pad) if not is_last else nn.Identity(),
]))
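
# Note: relative_resolution equals the stage index, so a stage at index i operates on
# feature maps downsampled by a factor of 2**i; attention_resolutions=(1, 2), for example,
# applies AttentionBlock at the 2x- and 4x-downsampled stages and nn.Identity() elsewhere.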

mid_channels = channels[-1]
self.mid1 = ResidualBlock(
mid_channels, mid_channels, dropout,
time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)
self.mid_attn = AttentionBlock(mid_channels) if use_attn else nn.Identity()
self.mid_attn = AttentionBlock(mid_channels)
self.mid2 = ResidualBlock(
mid_channels, mid_channels, dropout,
time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)

for ind, (in_channels, out_channels) in enumerate(reversed(channel_pairs[1:])):
relative_resolution = len(channel_mults) - ind - 1

self.ups.append(nn.ModuleList([
ResidualBlock(out_channels * 2, in_channels, dropout,
time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
ResidualBlock(in_channels, in_channels, dropout,
time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
AttentionBlock(in_channels) if use_attn else nn.Identity(),
AttentionBlock(in_channels) if relative_resolution in attention_resolutions else nn.Identity(),
Upsample(in_channels, align_corners=align_corners, use_reflection_pad=use_reflection_pad),
]))

self.out_conv = nn.Sequential(
ConvBlock(base_channels, base_channels, activation=activation, norm=norm, num_groups=num_groups),
nn.Conv2d(base_channels, img_channels, 1),
)
self.out_norm = get_norm(norm, base_channels, num_groups)
self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

def forward(self, x, time=None, y=None):
ip = self.initial_pad
@@ -341,6 +317,8 @@ def forward(self, x, time=None, y=None):
if self.num_classes is not None and y is None:
raise ValueError("class conditioning was specified but y is not passed")

x = self.init_conv(x)

skips = []

for r1, r2, attn, downsample in self.downs:
@@ -359,8 +337,11 @@ def forward(self, x, time=None, y=None):
x = r2(x, time_emb, y)
x = attn(x)
x = upsample(x)

x = self.activation(self.out_norm(x))
x = self.out_conv(x)

if self.initial_pad != 0:
return self.out_conv(x)[:, :, ip:-ip, ip:-ip]
return x[:, :, ip:-ip, ip:-ip]
else:
return self.out_conv(x)
return x
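
A minimal usage sketch of the updated module. The constructor arguments collapsed out of this diff (base_channels, channel_mults, time_emb_dim) are assumed to keep these names, and all values below are illustrative rather than taken from the commit:

import torch

from ddpm.unet import UNet  # assumes the package is importable as `ddpm`

model = UNet(
    img_channels=3,
    base_channels=128,
    channel_mults=(1, 2, 2, 2),
    time_emb_dim=512,
    dropout=0.1,
    attention_resolutions=(1,),  # attention at the 2x-downsampled stage
    norm="gn",
    num_groups=32,
)

x = torch.randn(8, 3, 32, 32)      # batch of images
t = torch.randint(0, 1000, (8,))   # diffusion timesteps
eps_pred = model(x, time=t)        # predicted noise, same spatial size as x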
