Added attention and dropout, fixed bugs in utils.py and diffusion.py
abarankab committed Aug 19, 2021
1 parent 90d0389 · commit 975987d
Showing 3 changed files with 67 additions and 14 deletions.
ddpm/diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -111,7 +111,7 @@ def sample(self, batch_size, device, y=None, use_ema=True):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

-           if t > 1:
+           if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

        return x.cpu().detach()
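Note on the diffusion.py fix: in DDPM sampling, noise sigma_t * z is re-injected after every denoising step except the final one (t == 0), where the clean sample is returned as-is; the old `t > 1` check also skipped the noise at t == 1. A self-contained toy sketch of the loop this hunk lives in (shapes, schedule, and the denoiser are stand-ins, not the repo's real code):

```python
import torch

# Toy sketch of the reverse-diffusion loop around this hunk; `sigma` and
# `remove_noise` are stand-ins for the model's noise schedule and
# denoising step. Only the `t > 0` condition reflects the actual change.
batch_size, num_timesteps = 4, 10
sigma = torch.linspace(1.0, 0.01, num_timesteps)   # assumed schedule
remove_noise = lambda x, t: 0.9 * x                # stand-in denoiser

x = torch.randn(batch_size, 3, 8, 8)
for t in range(num_timesteps - 1, -1, -1):
    t_batch = torch.tensor([t]).repeat(batch_size)
    x = remove_noise(x, t_batch)
    if t > 0:  # add noise on every step except the last (t == 0)
        x += sigma[t] * torch.randn_like(x)
```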
ddpm/unet.py (75 changes: 64 additions & 11 deletions)
@@ -3,6 +3,8 @@
import torch.nn as nn
import torch.nn.functional as F

+from einops import rearrange
+

class PositionalEmbedding(nn.Module):
    __doc__ = r"""Computes a positional embedding of timesteps.
@@ -156,14 +158,18 @@ def __init__(
        self,
        in_channels,
        out_channels,
+       dropout,
        time_emb_dim=None,
        num_classes=None,
        **kwargs,
    ):
        super().__init__()

        self.conv_block_1 = ConvBlock(in_channels, out_channels, **kwargs)
-       self.conv_block_2 = ConvBlock(out_channels, out_channels, **kwargs)
+       self.conv_block_2 = nn.Sequential(
+           nn.Dropout(p=dropout),
+           ConvBlock(out_channels, out_channels, **kwargs),
+       )

        self.time_bias = nn.Sequential(nn.ReLU(), nn.Linear(time_emb_dim, out_channels)) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None
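Note on the dropout change: a single `nn.Dropout` now sits immediately before the second conv block of each residual block, so it is active only in `train()` mode and leaves evaluation untouched. A minimal self-contained sketch of the wiring (plain `nn.Conv2d` stands in for the repo's `ConvBlock`, which also normalizes and applies an activation):

```python
import torch
import torch.nn as nn

# Dropout-before-second-conv pattern; nn.Dropout zeroes individual
# elements with probability p during training and is a no-op in eval().
conv_block_2 = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Conv2d(64, 64, 3, padding=1),  # stand-in for ConvBlock
)
x = torch.randn(2, 64, 32, 32)
print(conv_block_2(x).shape)  # torch.Size([2, 64, 32, 32])
```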
@@ -189,8 +195,41 @@ def forward(self, x, time_emb=None, y=None):
        return out + self.residual_connection(x)


+class AttentionBlock(nn.Module):
+    __doc__ = r"""Applies linear attention with a residual connection.
+
+    Input:
+        x: tensor of shape (N, in_channels, H, W)
+    Output:
+        tensor of shape (N, in_channels, H, W)
+    Args:
+        in_channels (int): number of input channels
+        heads (int): number of attention heads
+        head_channels (int): number of channels in a head
+    """
+    def __init__(self, in_channels, heads=4, head_channels=32):
+        super().__init__()
+        self.heads = heads
+
+        self.norm = nn.InstanceNorm2d(in_channels, affine=True)
+
+        mid_channels = head_channels * heads
+        self.to_qkv = nn.Conv2d(in_channels, mid_channels * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(mid_channels, in_channels, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(self.norm(x))
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
+        k = k.softmax(dim=-1)
+        context = torch.einsum('bhdn,bhen->bhde', k, v)
+        out = torch.einsum('bhde,bhdn->bhen', context, q)
+        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+        return self.to_out(out) + x
+

class UNet(nn.Module):
-    __doc__ = """UNet model used to estimate noise. No attention is used.
+    __doc__ = """UNet model used to estimate noise.
    Input:
        x: tensor of shape (N, in_channels, H, W)
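Note on the new AttentionBlock: this is linear ("efficient") attention rather than standard dot-product attention. Keys are softmax-normalized over the H*W positions, the first einsum builds a small per-head context matrix of shape (head_channels, head_channels), and the second einsum reads the queries against it, so cost scales as O(H*W*d^2) instead of O((H*W)^2 * d). A quick shape sanity check, assuming the class is importable as written:

```python
import torch
from ddpm.unet import AttentionBlock  # assumed import path

# The residual connection and the 1x1 `to_out` conv guarantee the output
# shape matches the input shape regardless of heads/head_channels.
attn = AttentionBlock(in_channels=64, heads=4, head_channels=32)
x = torch.randn(2, 64, 16, 16)
assert attn(x).shape == x.shape
```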
@@ -206,6 +245,8 @@ class UNet(nn.Module):
        time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
+       dropout (float): dropout rate at the end of each residual block
+       use_attn (bool): if True, linear attention is used in residual blocks. Default: True
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "in"
        num_groups (int): number of groups used in group normalization. Default: 8
        align_corners (bool): align_corners in bilinear upsampling. Default: True
@@ -222,6 +263,8 @@ def __init__(
        time_emb_scale=1.0,
        num_classes=None,
        activation=F.relu,
+       dropout=0.1,
+       use_attn=True,
        norm="in",
        num_groups=8,
        align_corners=True,
@@ -250,24 +293,30 @@
            is_last = (ind == len(channel_pairs) - 1)

            self.downs.append(nn.ModuleList([
-               ResidualBlock(in_channels, out_channels,
+               ResidualBlock(in_channels, out_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
-               ResidualBlock(out_channels, out_channels,
+               ResidualBlock(out_channels, out_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
+               AttentionBlock(out_channels) if use_attn else nn.Identity(),
                Downsample(out_channels, use_reflection_pad=use_reflection_pad) if not is_last else nn.Identity(),
            ]))

        mid_channels = channels[-1]
-       self.mid = ResidualBlock(
-           mid_channels, mid_channels,
+       self.mid1 = ResidualBlock(
+           mid_channels, mid_channels, dropout,
            time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)
+       self.mid_attn = AttentionBlock(mid_channels) if use_attn else nn.Identity()
+       self.mid2 = ResidualBlock(
+           mid_channels, mid_channels, dropout,
+           time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups)

        for ind, (in_channels, out_channels) in enumerate(reversed(channel_pairs[1:])):
            self.ups.append(nn.ModuleList([
-               ResidualBlock(out_channels * 2, in_channels,
+               ResidualBlock(out_channels * 2, in_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
-               ResidualBlock(in_channels, in_channels,
+               ResidualBlock(in_channels, in_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups),
+               AttentionBlock(in_channels) if use_attn else nn.Identity(),
                Upsample(in_channels, align_corners=align_corners, use_reflection_pad=use_reflection_pad),
            ]))
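Note on the `use_attn` toggle: the attention slot must hold a module instance — `nn.Identity()`, not the bare `nn.Identity` class (corrected above) — so that `nn.ModuleList` accepts it and the forward loops can call `attn(x)` unconditionally. A tiny check of the no-op behaviour:

```python
import torch
import torch.nn as nn

# nn.Identity() returns its input unchanged, so it is a safe stand-in for
# the attention block when use_attn=False; no branching needed in forward.
attn = nn.Identity()
x = torch.randn(2, 64, 16, 16)
assert torch.equal(attn(x), x)
```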

@@ -294,17 +343,21 @@ def forward(self, x, time=None, y=None):

        skips = []

-       for r1, r2, downsample in self.downs:
+       for r1, r2, attn, downsample in self.downs:
            x = r1(x, time_emb, y)
            x = r2(x, time_emb, y)
+           x = attn(x)
            skips.append(x)
            x = downsample(x)

-       x = self.mid(x, time_emb, y)
+       x = self.mid1(x, time_emb, y)
+       x = self.mid_attn(x)
+       x = self.mid2(x, time_emb, y)

-       for r1, r2, upsample in self.ups:
+       for r1, r2, attn, upsample in self.ups:
            x = r1(torch.cat([x, skips.pop()], dim=1), time_emb, y)
            x = r2(x, time_emb, y)
+           x = attn(x)
            x = upsample(x)

        if self.initial_pad != 0:
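An end-to-end sketch of the updated UNet. Only `dropout` and `use_attn` (plus the keyword arguments visible above) are confirmed by this diff; the leading constructor arguments are assumed names for illustration:

```python
import torch
from ddpm.unet import UNet  # assumed import path

# Hypothetical instantiation: with use_attn=True each residual pair is
# followed by an AttentionBlock and the bottleneck runs
# mid1 -> mid_attn -> mid2; dropout is applied inside every ResidualBlock.
model = UNet(
    img_channels=3,         # assumed parameter name
    base_channels=128,      # assumed parameter name
    channel_mults=(1, 2),   # assumed parameter name
    time_emb_dim=512,       # assumed parameter name
    dropout=0.1,
    use_attn=True,
)
x = torch.randn(4, 3, 32, 32)
t = torch.randint(0, 1000, (4,))
eps = model(x, time=t)      # predicted noise, same shape as the input
assert eps.shape == x.shape
```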
ddpm/utils.py (4 changes: 2 additions & 2 deletions)
@@ -53,9 +53,9 @@ def show_images(image_tensor, rows, cols=None, colorbar=False):

            plt.subplot(rows, cols, i * cols + j + 1)
            if is_rgb:
-               plt.imshow(image_tensor[i * cols + j].permute(1, 2, 0))
+               plt.imshow(((image_tensor[i * cols + j] + 1) / 2).permute(1, 2, 0))
            else:
-               plt.imshow(image_tensor[i * cols + j].squeeze(), cmap="gray")
+               plt.imshow(((image_tensor[i * cols + j] + 1) / 2).squeeze(), cmap="gray")

            if colorbar:
                plt.colorbar()
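Note on the utils.py fix: samples from the model live in [-1, 1] (the usual DDPM normalization), while `plt.imshow` expects RGB floats in [0, 1] and clips values outside that range, distorting the result. The fix maps tensors back with (x + 1) / 2 before plotting. A minimal sketch of the rescale (hypothetical helper; the commit inlines the expression instead):

```python
import torch

def to_unit_range(img: torch.Tensor) -> torch.Tensor:
    """Map a [-1, 1]-normalized image tensor to [0, 1] for display."""
    return (img + 1) / 2

x = torch.rand(3, 8, 8) * 2 - 1   # fake sample in [-1, 1]
y = to_unit_range(x)
assert 0.0 <= y.min() and y.max() <= 1.0
```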
