diff --git a/ddpm/unet.py b/ddpm/unet.py
index ef457ba..89ee662 100644
--- a/ddpm/unet.py
+++ b/ddpm/unet.py
@@ -6,6 +6,17 @@
 from einops import rearrange


+def get_norm(norm, num_features, num_groups):
+    if norm == "in":
+        return nn.InstanceNorm2d(num_features, affine=True)
+    elif norm == "bn":
+        return nn.BatchNorm2d(num_features)
+    elif norm == "gn":
+        return nn.GroupNorm(num_groups, num_features)
+    else:
+        raise ValueError("unknown normalization type")
+
+
 class PositionalEmbedding(nn.Module):
     __doc__ = r"""Computes a positional embedding of timesteps.

@@ -87,54 +98,6 @@ def forward(self, x):
         return self.upsample(x)


-class ConvBlock(nn.Module):
-    __doc__ = r"""Applies 2d convolution, normalization and activation to a tensor.
-
-    Input:
-        x: tensor of shape (N, in_channels, H, W)
-    Output:
-        tensor of shape (N, out_channels, H, W)
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        activation (function): activation function. Default: torch.nn.functional.relu
-        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
-        num_groups (int): number of groups used in group normalization. Default: 8
-    """
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        activation=F.relu,
-        norm="in",
-        num_groups=8,
-    ):
-        super().__init__()
-
-        self.activation = activation
-
-        modules = []
-        modules.append(nn.Conv2d(
-            in_channels, out_channels, 3, padding=1,
-            bias=False if norm is not None else True,
-        ))
-
-        if norm == "in":
-            modules.append(nn.InstanceNorm2d(out_channels, affine=True))
-        elif norm == "gn":
-            modules.append(nn.GroupNorm(num_groups, out_channels))
-        elif norm == "bn":
-            modules.append(nn.BatchNorm2d(out_channels))
-        elif norm is not None:
-            raise ValueError("__init__() got unknown normalization type")
-
-        self.block = nn.Sequential(*modules)
-
-    def forward(self, x):
-        return self.activation(self.block(x))
-
-
 class ResidualBlock(nn.Module):
     __doc__ = r"""Applies two conv blocks with resudual connection. Adds time and class conditioning by adding bias after first convolution.

@@ -150,8 +113,8 @@ class ResidualBlock(nn.Module):
         time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
         num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
         activation (function): activation function. Default: torch.nn.functional.relu
-        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
-        num_groups (int): number of groups used in group normalization. Default: 8
+        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
+        num_groups (int): number of groups used in group normalization. Default: 32
     """

@@ -161,42 +124,51 @@ def __init__(
         self,
         in_channels,
         out_channels,
         dropout,
         time_emb_dim=None,
         num_classes=None,
-        **kwargs,
+        activation=F.relu,
+        norm="gn",
+        num_groups=32,
     ):
         super().__init__()

-        self.conv_block_1 = ConvBlock(in_channels, out_channels, **kwargs)
-        self.conv_block_2 = nn.Sequential(
+        self.activation = activation
+
+        self.norm_1 = get_norm(norm, in_channels, num_groups)
+        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+
+        self.conv_2 = nn.Sequential(
             nn.Dropout(p=dropout),
-            ConvBlock(out_channels, out_channels, **kwargs),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1),
         )
+        self.norm_2 = get_norm(norm, out_channels, num_groups)

-        self.time_bias = nn.Sequential(nn.ReLU(), nn.Linear(time_emb_dim, out_channels)) if time_emb_dim is not None else None
+        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
         self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

         self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

     def forward(self, x, time_emb=None, y=None):
-        out = self.conv_block_1(x)
+        out = self.activation(self.norm_1(x))
+        out = self.conv_1(out)

         if self.time_bias is not None:
             if time_emb is None:
                 raise ValueError("time conditioning was specified but time_emb is not passed")
-            out += self.time_bias(time_emb)[:, :, None, None]
+            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

         if self.class_bias is not None:
             if y is None:
                 raise ValueError("class conditioning was specified but y is not passed")
             out += self.class_bias(y)[:, :, None, None]
-
-        out = self.conv_block_2(out)
+
+        out = self.activation(self.norm_2(out))
+        out = self.conv_2(out)

         return out + self.residual_connection(x)


 class AttentionBlock(nn.Module):
-    __doc__ = r"""Applies linear attention with a residual connection.
+    __doc__ = r"""Applies attention with a residual connection. This part differs a lot from what was used in the paper.

     Input:
         x: tensor of shape (N, in_channels, H, W)
@@ -211,7 +183,7 @@ def __init__(self, in_channels, heads=4, head_channels=32):
         super().__init__()

         self.heads = heads
-        self.norm = nn.InstanceNorm2d(in_channels, affine=True)
+        self.norm = nn.InstanceNorm2d(in_channels, affine=True)  # GroupNorm in paper

         mid_channels = head_channels * heads
         self.to_qkv = nn.Conv2d(in_channels, mid_channels * 3, 1, bias=False)
@@ -246,9 +218,9 @@ class UNet(nn.Module):
         num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
         activation (function): activation function. Default: torch.nn.functional.relu
         dropout (float): dropout rate at the end of each residual block
-        use_attn (bool): it True linear attention is used in residual blocks. Default: True
-        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
-        num_groups (int): number of groups used in group normalization. Default: 8
+        attention_resolutions (tuple): relative resolutions (downsampling level indices, 0 = full resolution) at which attention is applied. Default: ()
+        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
+        num_groups (int): number of groups used in group normalization. Default: 32
         align_corners (bool): align_corners in bilinear upsampling. Default: True
         use_reflection_pad (bool): if True reflection pad is used, otherwise zero pad is used. Default: False
         initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0
@@ -264,15 +236,16 @@ def __init__(
         num_classes=None,
         activation=F.relu,
         dropout=0.1,
-        use_attn=True,
-        norm="in",
-        num_groups=8,
+        attention_resolutions=(),
+        norm="gn",
+        num_groups=32,
         align_corners=True,
         use_reflection_pad=False,
         initial_pad=0,
     ):
         super().__init__()

+        self.activation = activation
         self.initial_pad = initial_pad
         self.num_classes = num_classes
@@ -283,6 +256,8 @@
             nn.Linear(time_emb_dim * 4, time_emb_dim),
         ) if time_emb_dim is not None else None

-        channels = (img_channels, *[base_channels * mult for mult in channel_mults])
+        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)
+
+        channels = (base_channels, *[base_channels * mult for mult in channel_mults])
         channel_pairs = tuple(zip(channels[:-1], channels[1:]))
@@ -291,13 +266,14 @@
         for ind, (in_channels, out_channels) in enumerate(channel_pairs):
             is_last = (ind == len(channel_pairs) - 1)
+            relative_resolution = ind  # 0 = full input resolution

             self.downs.append(nn.ModuleList([
                 ResidualBlock(in_channels, out_channels, dropout, time_emb_dim=time_emb_dim,
                               num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
                 ResidualBlock(out_channels, out_channels, dropout, time_emb_dim=time_emb_dim,
                               num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
-                AttentionBlock(out_channels) if use_attn else nn.Identity(),
+                AttentionBlock(out_channels) if relative_resolution in attention_resolutions else nn.Identity(),
                 Downsample(out_channels, use_reflection_pad=use_reflection_pad) if not is_last else nn.Identity(),
             ]))
@@ -305,25 +281,25 @@
         self.mid1 = ResidualBlock(
             mid_channels, mid_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)
-        self.mid_attn = AttentionBlock(mid_channels) if use_attn else nn.Identity()
+        self.mid_attn = AttentionBlock(mid_channels)
         self.mid2 = ResidualBlock(
             mid_channels, mid_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)

         for ind, (in_channels, out_channels) in enumerate(reversed(channel_pairs[1:])):
+            relative_resolution = len(channel_mults) - ind - 1  # mirrors the corresponding down level
+
             self.ups.append(nn.ModuleList([
                 ResidualBlock(out_channels * 2, in_channels, dropout, time_emb_dim=time_emb_dim,
                               num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
                 ResidualBlock(in_channels, in_channels, dropout, time_emb_dim=time_emb_dim,
                               num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
-                AttentionBlock(in_channels) if use_attn else nn.Identity(),
+                AttentionBlock(in_channels) if relative_resolution in attention_resolutions else nn.Identity(),
                 Upsample(in_channels, align_corners=align_corners, use_reflection_pad=use_reflection_pad),
             ]))

-        self.out_conv = nn.Sequential(
-            ConvBlock(base_channels, base_channels, activation=activation, norm=norm, num_groups=num_groups),
-            nn.Conv2d(base_channels, img_channels, 1),
-        )
+        self.out_norm = get_norm(norm, base_channels, num_groups)
+        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

     def forward(self, x, time=None, y=None):
         ip = self.initial_pad
@@ -341,6 +317,8 @@
         if self.num_classes is not None and y is None:
             raise ValueError("class conditioning was specified but y is not passed")

+        x = self.init_conv(x)
+
         skips = []

         for r1, r2, attn, downsample in self.downs:
@@ -359,8 +337,11 @@
             x = r2(x, time_emb, y)
             x = attn(x)
             x = upsample(x)
+
+        x = self.activation(self.out_norm(x))
+        x = self.out_conv(x)

         if self.initial_pad != 0:
-            return self.out_conv(x)[:, :, ip:-ip, ip:-ip]
+            return x[:, :, ip:-ip, ip:-ip]
         else:
-            return self.out_conv(x)
+            return x
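For reference, a minimal usage sketch of the revised interface follows. It is an illustration only: the full UNet constructor signature is not visible in this diff, so the keyword arguments below are inferred from the hunks above and the concrete values are arbitrary. attention_resolutions indexes downsampling levels, with 0 being the input resolution.

import torch

from ddpm.unet import UNet

# Hypothetical configuration; keyword names inferred from the diff, values arbitrary.
model = UNet(
    img_channels=3,
    base_channels=64,
    channel_mults=(1, 2, 4, 8),
    time_emb_dim=256,
    num_classes=10,
    dropout=0.1,
    attention_resolutions=(1,),  # attention blocks at downsampling level 1
    norm="gn",
    num_groups=32,
)

x = torch.randn(8, 3, 32, 32)         # batch of (noisy) images
t = torch.randint(0, 1000, (8,))      # diffusion timesteps
y = torch.randint(0, 10, (8,))        # class labels

out = model(x, time=t, y=y)           # model output (e.g. predicted noise), same shape as x
assert out.shape == x.shape

With channel_mults=(1, 2, 4, 8) and 32x32 inputs, attention_resolutions=(1,) places attention on the 16x16 feature maps in both the down and up paths.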