diff --git a/ddpm/diffusion.py b/ddpm/diffusion.py
index 6a64fb5..029f720 100644
--- a/ddpm/diffusion.py
+++ b/ddpm/diffusion.py
@@ -111,7 +111,7 @@ def sample(self, batch_size, device, y=None, use_ema=True):
             t_batch = torch.tensor([t], device=device).repeat(batch_size)
             x = self.remove_noise(x, t_batch, y, use_ema)
 
-            if t > 1:
+            if t > 0:
                 x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
 
         return x.cpu().detach()
diff --git a/ddpm/unet.py b/ddpm/unet.py
index ab71d75..57725e3 100644
--- a/ddpm/unet.py
+++ b/ddpm/unet.py
@@ -3,6 +3,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from einops import rearrange
+
 
 class PositionalEmbedding(nn.Module):
     __doc__ = r"""Computes a positional embedding of timesteps.
@@ -156,6 +158,7 @@ def __init__(
         self,
         in_channels,
         out_channels,
+        dropout,
         time_emb_dim=None,
         num_classes=None,
         **kwargs,
@@ -163,7 +166,10 @@
         super().__init__()
 
         self.conv_block_1 = ConvBlock(in_channels, out_channels, **kwargs)
-        self.conv_block_2 = ConvBlock(out_channels, out_channels, **kwargs)
+        self.conv_block_2 = nn.Sequential(
+            nn.Dropout(p=dropout),
+            ConvBlock(out_channels, out_channels, **kwargs),
+        )
 
         self.time_bias = nn.Sequential(nn.ReLU(), nn.Linear(time_emb_dim, out_channels)) if time_emb_dim is not None else None
         self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None
@@ -189,8 +195,41 @@ def forward(self, x, time_emb=None, y=None):
         return out + self.residual_connection(x)
 
 
+class AttentionBlock(nn.Module):
+    __doc__ = r"""Applies linear attention with a residual connection.
+
+    Input:
+        x: tensor of shape (N, in_channels, H, W)
+    Output:
+        tensor of shape (N, in_channels, H, W)
+    Args:
+        in_channels (int): number of input channels
+        heads (int): number of attention heads
+        head_channels (int): number of channels in a head
+    """
+    def __init__(self, in_channels, heads=4, head_channels=32):
+        super().__init__()
+        self.heads = heads
+
+        self.norm = nn.InstanceNorm2d(in_channels, affine=True)
+
+        mid_channels = head_channels * heads
+        self.to_qkv = nn.Conv2d(in_channels, mid_channels * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(mid_channels, in_channels, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(self.norm(x))
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
+        k = k.softmax(dim=-1)
+        context = torch.einsum('bhdn,bhen->bhde', k, v)
+        out = torch.einsum('bhde,bhdn->bhen', context, q)
+        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+        return self.to_out(out) + x
+
+
 class UNet(nn.Module):
-    __doc__ = """UNet model used to estimate noise. No attention is used.
+    __doc__ = """UNet model used to estimate noise.
 
     Input:
         x: tensor of shape (N, in_channels, H, W)
@@ -206,6 +245,8 @@ class UNet(nn.Module):
        time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
+       dropout (float): dropout rate at the end of each residual block. Default: 0.1
+       use_attn (bool): if True, linear attention is used in residual blocks. Default: True
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
        num_groups (int): number of groups used in group normalization. Default: 8
        align_corners (bool): align_corners in bilinear upsampling. Default: True
@@ -222,6 +263,8 @@ def __init__(
         time_emb_scale=1.0,
         num_classes=None,
         activation=F.relu,
+        dropout=0.1,
+        use_attn=True,
         norm="in",
         num_groups=8,
         align_corners=True,
@@ -250,24 +293,30 @@ def __init__(
             is_last = (ind == len(channel_pairs) - 1)
 
             self.downs.append(nn.ModuleList([
-                ResidualBlock(in_channels, out_channels,
+                ResidualBlock(in_channels, out_channels, dropout,
                               time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
-                ResidualBlock(out_channels, out_channels,
+                ResidualBlock(out_channels, out_channels, dropout,
                               time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
+                AttentionBlock(out_channels) if use_attn else nn.Identity(),
                 Downsample(out_channels, use_reflection_pad=use_reflection_pad) if not is_last else nn.Identity(),
             ]))
 
         mid_channels = channels[-1]
 
-        self.mid = ResidualBlock(
-            mid_channels, mid_channels,
+        self.mid1 = ResidualBlock(
+            mid_channels, mid_channels, dropout,
+            time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)
+        self.mid_attn = AttentionBlock(mid_channels) if use_attn else nn.Identity()
+        self.mid2 = ResidualBlock(
+            mid_channels, mid_channels, dropout,
             time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)
 
         for ind, (in_channels, out_channels) in enumerate(reversed(channel_pairs[1:])):
             self.ups.append(nn.ModuleList([
-                ResidualBlock(out_channels * 2, in_channels,
+                ResidualBlock(out_channels * 2, in_channels, dropout,
                               time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
-                ResidualBlock(in_channels, in_channels,
+                ResidualBlock(in_channels, in_channels, dropout,
                               time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
+                AttentionBlock(in_channels) if use_attn else nn.Identity(),
                 Upsample(in_channels, align_corners=align_corners, use_reflection_pad=use_reflection_pad),
             ]))
@@ -294,17 +343,21 @@ def forward(self, x, time=None, y=None):
 
         skips = []
 
-        for r1, r2, downsample in self.downs:
+        for r1, r2, attn, downsample in self.downs:
             x = r1(x, time_emb, y)
             x = r2(x, time_emb, y)
+            x = attn(x)
             skips.append(x)
             x = downsample(x)
 
-        x = self.mid(x, time_emb, y)
+        x = self.mid1(x, time_emb, y)
+        x = self.mid_attn(x)
+        x = self.mid2(x, time_emb, y)
 
-        for r1, r2, upsample in self.ups:
+        for r1, r2, attn, upsample in self.ups:
             x = r1(torch.cat([x, skips.pop()], dim=1), time_emb, y)
             x = r2(x, time_emb, y)
+            x = attn(x)
             x = upsample(x)
 
         if self.initial_pad != 0:
diff --git a/ddpm/utils.py b/ddpm/utils.py
index 65a298b..1f1fef4 100644
--- a/ddpm/utils.py
+++ b/ddpm/utils.py
@@ -53,9 +53,9 @@ def show_images(image_tensor, rows, cols=None, colorbar=False):
             plt.subplot(rows, cols, i * cols + j + 1)
 
             if is_rgb:
-                plt.imshow(image_tensor[i * cols + j].permute(1, 2, 0))
+                plt.imshow(((image_tensor[i * cols + j] + 1) / 2).permute(1, 2, 0))
             else:
-                plt.imshow(image_tensor[i * cols + j].squeeze(), cmap="gray")
+                plt.imshow(((image_tensor[i * cols + j] + 1) / 2).squeeze(), cmap="gray")
 
             if colorbar:
                 plt.colorbar()
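
Notes on the patch follow.

The new AttentionBlock is a linear-attention variant: k is softmaxed over the spatial dimension and contracted with v first, so the intermediate context tensor is only head_channels x head_channels per head, and the cost grows linearly with the number of pixels rather than quadratically. A minimal sanity check, assuming the patch is applied so that AttentionBlock is importable from ddpm.unet:

import torch

from ddpm.unet import AttentionBlock

block = AttentionBlock(in_channels=64, heads=4, head_channels=32)
x = torch.randn(2, 64, 16, 16, requires_grad=True)

out = block(x)
assert out.shape == x.shape   # the residual connection preserves (N, C, H, W)

out.sum().backward()          # gradients reach the input through the attention path
assert x.grad is not None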
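Two smaller fixes ride along. In ddpm/diffusion.py, the reverse process now adds sigma-scaled noise at every step except the final one (t == 0); the old t > 1 guard stopped one step early, making the t == 1 update deterministic. In ddpm/utils.py, show_images now maps samples from the model's [-1, 1] range back into [0, 1] before handing them to matplotlib. A short usage sketch of the display fix, with fake_samples standing in for real model output:

import torch
import matplotlib.pyplot as plt

from ddpm.utils import show_images

# Stand-in for diffusion.sample(...): values in [-1, 1], like real model output.
fake_samples = torch.rand(16, 3, 32, 32) * 2 - 1

show_images(fake_samples, rows=4, cols=4)  # now rescales [-1, 1] -> [0, 1] internally
plt.show()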