From 06d26c3706ed4bc2a253936947dae199f8b3e3f8 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 22 Jan 2024 07:12:55 -0800 Subject: [PATCH 01/21] begin work on the karras unet --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index b493dad5a..907cbc86f 100644 --- a/README.md +++ b/README.md @@ -343,3 +343,14 @@ You could consider adding a suitable metric to the training loop yourself after url = {https://api.semanticscholar.org/CorpusID:259224568} } ``` + +```bibtex +@article{Karras2023AnalyzingAI, + title = {Analyzing and Improving the Training Dynamics of Diffusion Models}, + author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine}, + journal = {ArXiv}, + year = {2023}, + volume = {abs/2312.02696}, + url = {https://api.semanticscholar.org/CorpusID:265659032} +} +``` From 33ff8ed0b10db7cf847de9a1b4da92148bb78637 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 2 Feb 2024 13:00:17 -0800 Subject: [PATCH 02/21] begin by removing all biases and groupnorms --- denoising_diffusion_pytorch/__init__.py | 2 + denoising_diffusion_pytorch/karras_unet.py | 329 +++++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 denoising_diffusion_pytorch/karras_unet.py diff --git a/denoising_diffusion_pytorch/__init__.py b/denoising_diffusion_pytorch/__init__.py index adba20b10..05ded85d2 100644 --- a/denoising_diffusion_pytorch/__init__.py +++ b/denoising_diffusion_pytorch/__init__.py @@ -7,3 +7,5 @@ from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D + +from denoising_diffusion_pytorch.karras_unet import KarrasUnet diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py new file mode 100644 index 000000000..3215e7fd2 --- /dev/null +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -0,0 +1,329 @@ +import math +import copy +from pathlib import Path +from random import random +from functools import partial +from collections import namedtuple + +import torch +from torch import nn, einsum +from torch.cuda.amp import autocast +import torch.nn.functional as F + +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange + +from denoising_diffusion_pytorch.attend import Attend + +# helpers functions + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def cast_tuple(t, length = 1): + if isinstance(t, tuple): + return t + return ((t,) * length) + +def divisible_by(numer, denom): + return (numer % denom) == 0 + +def identity(t, *args, **kwargs): + return t + +# small helper modules + +def Upsample(dim, dim_out = None): + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1, bias = False) + ) + +def Downsample(dim, dim_out = None): + return nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), + nn.Conv2d(dim * 4, default(dim_out, dim), 1, bias = False) + ) + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5) + +# sinusoidal positional embeds + +class SinusoidalPosEmb(nn.Module): + def 
__init__(self, dim, theta = 10000): + super().__init__() + self.dim = dim + self.theta = theta + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(self.theta) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class RandomOrLearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim, is_random = False): + super().__init__() + assert divisible_by(dim, 2) + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +# building block modules + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): + super().__init__() + self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1, bias = False) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2, bias = False) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1, bias = False) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + heads = 4, + dim_head = 32, + num_mem_kv = 4, + flash = False + ): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = RMSNorm(dim) + self.attend = Attend(flash = flash) + + self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head)) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1, bias = False) + + def forward(self, x): + b, c, h, w = x.shape + + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv) + + mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv) + k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v))) + + out = self.attend(q, k, v) + + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + return self.to_out(out) + +# model + +class KarrasUnet(nn.Module): + def __init__( + self, + dim, + init_dim = None, + out_dim = None, + dim_mults = (1, 2, 4, 8), + channels = 3, + self_condition = False, + resnet_block_groups = 8, + learned_variance = False, + learned_sinusoidal_cond = False, + random_fourier_features = False, + 
learned_sinusoidal_dim = 16, + sinusoidal_pos_emb_theta = 10000, + attn_dim_head = 32, + attn_heads = 4, + full_attn = None, # defaults to full attention only for inner most layer + flash_attn = False + ): + super().__init__() + + # determine dimensions + + self.channels = channels + self.self_condition = self_condition + input_channels = channels * (2 if self_condition else 1) + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # time embeddings + + time_dim = dim * 4 + + self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features + + if self.random_or_learned_sinusoidal_cond: + sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) + fourier_dim = learned_sinusoidal_dim + 1 + else: + sinu_pos_emb = SinusoidalPosEmb(dim, theta = sinusoidal_pos_emb_theta) + fourier_dim = dim + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # attention + + if not full_attn: + full_attn = (*((False,) * (len(dim_mults) - 1)), True) + + num_stages = len(dim_mults) + full_attn = cast_tuple(full_attn, num_stages) + attn_heads = cast_tuple(attn_heads, num_stages) + attn_dim_head = cast_tuple(attn_dim_head, num_stages) + + assert len(full_attn) == len(dim_mults) + + FullAttention = partial(Attention, flash = flash_attn) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + FullAttention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + + for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + FullAttention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels * (1 if not learned_variance else 2) + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1, bias = False) + + @property + def downsample_factor(self): + return 2 ** (len(self.downs) - 1) + + def forward(self, x, time, x_self_cond = None): + assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), 
f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet' + + if self.self_condition: + x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x_self_cond, x), dim = 1) + + x = self.init_conv(x) + r = x.clone() + + t = self.time_mlp(time) + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t) + h.append(x) + + x = block2(x, t) + x = attn(x) + x + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x + x = self.mid_block2(x, t) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + x + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x, t) + return self.final_conv(x) From 3f8d5eb5eb892af3dd6cc750f94bb6d96a26ee2e Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 2 Feb 2024 13:10:16 -0800 Subject: [PATCH 03/21] more cleanup --- denoising_diffusion_pytorch/karras_unet.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 3215e7fd2..c9dd58ad2 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -1,9 +1,6 @@ import math -import copy -from pathlib import Path from random import random from functools import partial -from collections import namedtuple import torch from torch import nn, einsum @@ -225,9 +222,9 @@ def __init__( self.time_mlp = nn.Sequential( sinu_pos_emb, - nn.Linear(fourier_dim, time_dim), + nn.Linear(fourier_dim, time_dim, bias = False), nn.GELU(), - nn.Linear(time_dim, time_dim) + nn.Linear(time_dim, time_dim, bias = False) ) # attention @@ -242,8 +239,6 @@ def __init__( assert len(full_attn) == len(dim_mults) - FullAttention = partial(Attention, flash = flash_attn) - # layers self.downs = nn.ModuleList([]) @@ -256,13 +251,13 @@ def __init__( self.downs.append(nn.ModuleList([ block_klass(dim_in, dim_in, time_emb_dim = time_dim), block_klass(dim_in, dim_in, time_emb_dim = time_dim), - FullAttention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads), + Attention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) ])) mid_dim = dims[-1] self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) + self.mid_attn = Attention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): @@ -271,8 +266,8 @@ def __init__( self.ups.append(nn.ModuleList([ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - FullAttention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + Attention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 
3, padding = 1, bias = False) ])) default_out_dim = channels * (1 if not learned_variance else 2) From 8aa36bc04fb5ce2c7f967838034cae24cd2d44dc Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 2 Feb 2024 16:46:23 -0800 Subject: [PATCH 04/21] take care of activation fns --- denoising_diffusion_pytorch/karras_unet.py | 44 +++++++++++++++------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index c9dd58ad2..9fd233e6b 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -4,6 +4,7 @@ import torch from torch import nn, einsum +from torch.nn import Module, ModuleList from torch.cuda.amp import autocast import torch.nn.functional as F @@ -47,7 +48,22 @@ def Downsample(dim, dim_out = None): nn.Conv2d(dim * 4, default(dim_out, dim), 1, bias = False) ) -class RMSNorm(nn.Module): +# mp activations +# section 2.5 + +class MPSiLU(Module): + def forward(self, x): + return F.silu(x) / 0.596 + +def sin(t): + return torch.sin(t) * (2 ** 0.5) + +def cos(t): + return torch.cos(t) * (2 ** 0.5) + +# norm + +class RMSNorm(Module): def __init__(self, dim): super().__init__() self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) @@ -57,7 +73,7 @@ def forward(self, x): # sinusoidal positional embeds -class SinusoidalPosEmb(nn.Module): +class SinusoidalPosEmb(Module): def __init__(self, dim, theta = 10000): super().__init__() self.dim = dim @@ -69,10 +85,10 @@ def forward(self, x): emb = math.log(self.theta) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + emb = torch.cat((sin(emb), cos(emb)), dim=-1) return emb -class RandomOrLearnedSinusoidalPosEmb(nn.Module): +class RandomOrLearnedSinusoidalPosEmb(Module): """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ @@ -91,11 +107,11 @@ def forward(self, x): # building block modules -class Block(nn.Module): +class Block(Module): def __init__(self, dim, dim_out, groups = 8): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1, bias = False) - self.act = nn.SiLU() + self.act = MPSiLU() def forward(self, x, scale_shift = None): x = self.proj(x) @@ -107,11 +123,11 @@ def forward(self, x, scale_shift = None): x = self.act(x) return x -class ResnetBlock(nn.Module): +class ResnetBlock(Module): def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): super().__init__() self.mlp = nn.Sequential( - nn.SiLU(), + MPSiLU(), nn.Linear(time_emb_dim, dim_out * 2, bias = False) ) if exists(time_emb_dim) else None @@ -133,7 +149,7 @@ def forward(self, x, time_emb = None): return h + self.res_conv(x) -class Attention(nn.Module): +class Attention(Module): def __init__( self, dim, @@ -171,7 +187,7 @@ def forward(self, x): # model -class KarrasUnet(nn.Module): +class KarrasUnet(Module): def __init__( self, dim, @@ -241,14 +257,14 @@ def __init__( # layers - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) + self.downs = ModuleList([]) + self.ups = ModuleList([]) num_resolutions = len(in_out) for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): is_last = ind >= (num_resolutions - 1) - self.downs.append(nn.ModuleList([ + self.downs.append(ModuleList([ 
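            # per resolution: two resnet blocks, attention, then a downsample (a plain conv at the last stage)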
block_klass(dim_in, dim_in, time_emb_dim = time_dim), block_klass(dim_in, dim_in, time_emb_dim = time_dim), Attention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), @@ -263,7 +279,7 @@ def __init__( for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): is_last = ind == (len(in_out) - 1) - self.ups.append(nn.ModuleList([ + self.ups.append(ModuleList([ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), Attention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), From 3c8fe052a694a5e69a123b4c39825443f2c40a4c Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 08:52:57 -0800 Subject: [PATCH 05/21] add gain --- denoising_diffusion_pytorch/karras_unet.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 9fd233e6b..9f09d82c4 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -61,6 +61,16 @@ def sin(t): def cos(t): return torch.cos(t) * (2 ** 0.5) +# gain - layer scaling + +class Gain(Module): + def __init__(self): + super().__init__() + self.gain = nn.Parameter(torch.tensor(0.)) + + def forward(self, x): + return x * self.gain + # norm class RMSNorm(Module): From f6ce3c54a4df1555a6a4d7dcbc636b8a1179ce60 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 09:01:53 -0800 Subject: [PATCH 06/21] complete the magnitude preserving concat, equation 103 --- denoising_diffusion_pytorch/karras_unet.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 9f09d82c4..40831af3f 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -1,4 +1,5 @@ import math +from math import sqrt from random import random from functools import partial @@ -71,6 +72,26 @@ def __init__(self): def forward(self, x): return x * self.gain +# magnitude preserving concat +# equation (103) - default to 0.5, which they recommended + +class MPCat(Module): + def __init__(self, t = 0.5, dim = -1): + super().__init__() + self.t = t + self.dim = dim + + def forward(self, a, b): + dim, t = self.dim, self.t + Na, Nb = a.shape[dim], b.shape[dim] + + C = sqrt((Na + Nb) / ((1 - t) ** 2 + t ** 2)) + + a = a * (1 - t) / sqrt(Na) + b = b * t / sqrt(Nb) + + return C * torch.cat((a, b), dim = dim) + # norm class RMSNorm(Module): From f36abd7bdeaa6a00fb016363dd11fae501c5ae8b Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 09:14:52 -0800 Subject: [PATCH 07/21] complete magnitude preserving sum for residuals and embedding --- denoising_diffusion_pytorch/karras_unet.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 40831af3f..335ee42b3 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -85,13 +85,29 @@ def forward(self, a, b): dim, t = self.dim, self.t Na, Nb = a.shape[dim], b.shape[dim] - C = sqrt((Na + Nb) / ((1 - t) ** 2 + t ** 2)) + C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2)) - a = a * (1 - t) / sqrt(Na) + a = a * (1. 
- t) / sqrt(Na) b = b * t / sqrt(Nb) return C * torch.cat((a, b), dim = dim) +# magnitude preserving sum +# equation (88) +# empirically, they found t=0.3 for encoder / decoder / attention residuals +# and for embedding, t=0.5 + +class MPSum(Module): + def __init__(self, t): + super().__init__() + self.t = t + + def forward(self, x, res): + a, b, t = x, res, self.t + num = a * (1. - t) + b * t + den = sqrt((1 - t) ** 2 + t ** 2) + return num / den + # norm class RMSNorm(Module): From ac8e7a4da5a303b17b6231db9a5281e1e85f95cb Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 09:21:48 -0800 Subject: [PATCH 08/21] they use cosine sim attention successfully --- denoising_diffusion_pytorch/karras_unet.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 335ee42b3..076c77b23 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -35,6 +35,11 @@ def divisible_by(numer, denom): def identity(t, *args, **kwargs): return t +# in paper, they use eps 1e-4 for pixelnorm + +def l2norm(t, dim = -1, eps = 1e-12): + return F.normalize(t, dim = dim, eps = eps) + # small helper modules def Upsample(dim, dim_out = None): @@ -196,7 +201,7 @@ def forward(self, x, time_emb = None): return h + self.res_conv(x) -class Attention(Module): +class CosineSimAttention(Module): def __init__( self, dim, @@ -212,6 +217,8 @@ def __init__( self.norm = RMSNorm(dim) self.attend = Attend(flash = flash) + self.temperatures = nn.Parameter(torch.zero(heads, 1, 1)) + self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head)) self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_out = nn.Conv2d(hidden_dim, dim, 1, bias = False) @@ -227,6 +234,10 @@ def forward(self, x): mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv) k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v))) + q, k = map(l2norm, (q, k)) + + q = q * self.temperature.exp() # unsure if they did learned temperature in paper + out = self.attend(q, k, v) out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) @@ -314,13 +325,13 @@ def __init__( self.downs.append(ModuleList([ block_klass(dim_in, dim_in, time_emb_dim = time_dim), block_klass(dim_in, dim_in, time_emb_dim = time_dim), - Attention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), + CosineSimAttention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) ])) mid_dim = dims[-1] self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - self.mid_attn = Attention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) + self.mid_attn = CosineSimAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): @@ -329,7 +340,7 @@ def __init__( self.ups.append(ModuleList([ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Attention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), + CosineSimAttention(dim_out, 
dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1, bias = False) ])) From 94f291ee22f8920fe94e58bb2893b74b4673552f Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 09:27:32 -0800 Subject: [PATCH 09/21] a fix scale was used for cosine sim attention, wow --- denoising_diffusion_pytorch/attend.py | 13 +++++++++++-- denoising_diffusion_pytorch/karras_unet.py | 6 ++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/denoising_diffusion_pytorch/attend.py b/denoising_diffusion_pytorch/attend.py index 6fd6e3bd9..fa689f906 100644 --- a/denoising_diffusion_pytorch/attend.py +++ b/denoising_diffusion_pytorch/attend.py @@ -17,6 +17,9 @@ def exists(val): return val is not None +def default(val, d): + return val if exists(val) else d + def once(fn): called = False @wraps(fn) @@ -36,10 +39,12 @@ class Attend(nn.Module): def __init__( self, dropout = 0., - flash = False + flash = False, + scale = None ): super().__init__() self.dropout = dropout + self.scale = scale self.attn_dropout = nn.Dropout(dropout) self.flash = flash @@ -65,6 +70,10 @@ def __init__( def flash_attn(self, q, k, v): _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + if exists(self.scale): + default_scale = q.shape[-1] + q = q * (scale / default_scale) + q, k, v = map(lambda t: t.contiguous(), (q, k, v)) # Check if there is a compatible device for flash attention @@ -95,7 +104,7 @@ def forward(self, q, k, v): if self.flash: return self.flash_attn(q, k, v) - scale = q.shape[-1] ** -0.5 + scale = default(self.scale, q.shape[-1] ** -0.5) # similarity diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 076c77b23..fb61b47fc 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -215,9 +215,9 @@ def __init__( hidden_dim = dim_head * heads self.norm = RMSNorm(dim) - self.attend = Attend(flash = flash) - self.temperatures = nn.Parameter(torch.zero(heads, 1, 1)) + # equation (34) - they used cosine sim of queries and keys with a fixed scale of sqrt(Nc) + self.attend = Attend(flash = flash, scale = dim_head ** 0.5) self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head)) self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) @@ -236,8 +236,6 @@ def forward(self, x): q, k = map(l2norm, (q, k)) - q = q * self.temperature.exp() # unsure if they did learned temperature in paper - out = self.attend(q, k, v) out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) From 8f3dafa91952426919e78e58631911b8b400ea11 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 09:39:08 -0800 Subject: [PATCH 10/21] pixel norm with rather high epsilon --- denoising_diffusion_pytorch/karras_unet.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index fb61b47fc..d24ee297f 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -113,6 +113,19 @@ def forward(self, x, res): den = sqrt((1 - t) ** 2 + t ** 2) return num / den +# pixelnorm +# equation (30) + +class PixelNorm(Module): + def __init__(self, dim, eps = 1e-4): + super().__init__() + # high epsilon for the pixel norm in the paper + self.dim = dim + self.eps = eps + + def forward(self, x): + dim = self.dim + return l2norm(x, dim = 
dim, eps = self.eps) * sqrt(x.shape[dim]) # norm class RMSNorm(Module): From 5c4955ae709e35cc0c32437bcfdfde47b307fdf6 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 10:07:36 -0800 Subject: [PATCH 11/21] complete the forced weight normed conv2d and linear following algorithm 1 in the paper --- denoising_diffusion_pytorch/karras_unet.py | 49 +++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index d24ee297f..eb21fdff4 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -9,7 +9,7 @@ from torch.cuda.amp import autocast import torch.nn.functional as F -from einops import rearrange, reduce, repeat +from einops import rearrange, reduce, repeat, pack, unpack from einops.layers.torch import Rearrange from denoising_diffusion_pytorch.attend import Attend @@ -24,6 +24,12 @@ def default(val, d): return val return d() if callable(d) else d +def pack_one(t, pattern): + return pack([t], pattern) + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + def cast_tuple(t, length = 1): if isinstance(t, tuple): return t @@ -126,6 +132,47 @@ def __init__(self, dim, eps = 1e-4): def forward(self, x): dim = self.dim return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim]) + +# forced weight normed conv2d and linear +# algorithm 1 in paper + +class WeightNormedConv2d(Module): + def __init__(self, dim_in, dim_out, kernel_size, eps = 1e-4): + super().__init__() + weight = torch.randn(dim_out, dim_in, kernel_size, kernel_size) + self.weight = nn.Parameter(weight) + + self.eps = eps + self.fan_in = dim_in * kernel_size ** 2 + + def forward(self, x): + if self.training: + with torch.no_grad(): + weight, ps = pack_one(self.weight, 'o *') + normed_weight = l2norm(weight, eps = self.eps) + normed_weight = unpack_one(normed_weight, ps, 'o *') + self.weight.copy_(normed_weight) + + weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in) + return F.conv2d(x, weight, padding='same') + +class WeightNormedLinear(Module): + def __init__(self, dim_in, dim_out, eps = 1e-4): + super().__init__() + weight = torch.randn(dim_out, dim_in) + self.weight = nn.Parameter(weight) + self.eps = eps + self.fan_in = dim_in + + def forward(self, x): + if self.training: + with torch.no_grad(): + normed_weight = l2norm(self.weight, eps = self.eps) + self.weight.copy_(normed_weight) + + weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in) + return F.linear(x, weight) + # norm class RMSNorm(Module): From ed98b624b6419e3ed946ff2a3ec3cf17d7e79a69 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 11:33:19 -0800 Subject: [PATCH 12/21] film is reduced to only scaling with no issues. 
bias not needed --- denoising_diffusion_pytorch/karras_unet.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index eb21fdff4..416373c87 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -225,12 +225,11 @@ def __init__(self, dim, dim_out, groups = 8): self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1, bias = False) self.act = MPSiLU() - def forward(self, x, scale_shift = None): + def forward(self, x, scale = None): x = self.proj(x) - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift + if exists(scale): + x = x * (scale + 1) x = self.act(x) return x @@ -249,13 +248,12 @@ def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): def forward(self, x, time_emb = None): - scale_shift = None + scale = None if exists(self.mlp) and exists(time_emb): - time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, 'b c -> b c 1 1') - scale_shift = time_emb.chunk(2, dim = 1) + scale = self.mlp(time_emb) + scale = rearrange(scale, 'b c -> b c 1 1') - h = self.block1(x, scale_shift = scale_shift) + h = self.block1(x, scale = scale) h = self.block2(h) From 77bf181ed6e3d3c42bfa5ee62affaa38cbe78e81 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 4 Feb 2024 08:33:32 -0800 Subject: [PATCH 13/21] add ability to concat ones to input for weight normed conv2d --- denoising_diffusion_pytorch/karras_unet.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 416373c87..47c444242 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -137,13 +137,21 @@ def forward(self, x): # algorithm 1 in paper class WeightNormedConv2d(Module): - def __init__(self, dim_in, dim_out, kernel_size, eps = 1e-4): + def __init__( + self, + dim_in, + dim_out, + kernel_size, + eps = 1e-4, + concat_ones_to_input = False # they use this in the input block to protect against loss of expressivity due to removal of all biases, even though they claim they observed none + ): super().__init__() - weight = torch.randn(dim_out, dim_in, kernel_size, kernel_size) + weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size) self.weight = nn.Parameter(weight) self.eps = eps self.fan_in = dim_in * kernel_size ** 2 + self.concat_ones_to_input = concat_ones_to_input def forward(self, x): if self.training: @@ -154,6 +162,10 @@ def forward(self, x): self.weight.copy_(normed_weight) weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in) + + if self.concat_ones_to_input: + x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.) 
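+            # the pad spec is read from the last dimension backwards (W, H, then C), so this prepends a single channel filled with ones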
+ return F.conv2d(x, weight, padding='same') class WeightNormedLinear(Module): From 86c7717e867af023a161c498d890c6db340e1a29 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 4 Feb 2024 08:43:41 -0800 Subject: [PATCH 14/21] pixel norm was applied to qkv, no pre-norm in cosine sim attention --- denoising_diffusion_pytorch/karras_unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 47c444242..c1c14d115 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -284,7 +284,7 @@ def __init__( self.heads = heads hidden_dim = dim_head * heads - self.norm = RMSNorm(dim) + self.pixel_norm = PixelNorm(dim = 1) # equation (34) - they used cosine sim of queries and keys with a fixed scale of sqrt(Nc) self.attend = Attend(flash = flash, scale = dim_head ** 0.5) @@ -296,14 +296,14 @@ def __init__( def forward(self, x): b, c, h, w = x.shape - x = self.norm(x) - qkv = self.to_qkv(x).chunk(3, dim = 1) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv) mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv) k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v))) + q, k, v = map(self.pixel_norm, (q, k, v)) + q, k = map(l2norm, (q, k)) out = self.attend(q, k, v) From c101b7066ed40d56d63e1a84b7ed847f27ff65c8 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 4 Feb 2024 09:26:47 -0800 Subject: [PATCH 15/21] box filter for both down and upsample --- denoising_diffusion_pytorch/karras_unet.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index c1c14d115..6c01b53bc 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -49,16 +49,20 @@ def l2norm(t, dim = -1, eps = 1e-12): # small helper modules def Upsample(dim, dim_out = None): - return nn.Sequential( - nn.Upsample(scale_factor = 2, mode = 'nearest'), - nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1, bias = False) - ) - -def Downsample(dim, dim_out = None): - return nn.Sequential( - Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), - nn.Conv2d(dim * 4, default(dim_out, dim), 1, bias = False) - ) + return nn.Upsample(scale_factor = 2, mode = 'bilinear') + +class Downsample(Module): + def __init__(self, dim): + super().__init__() + self.conv = WeightNormedConv2d(dim, dim, 1) + self.pixel_norm = PixelNorm(dim = 1) + + def forward(self, x): + h, w = x.shape[-2:] + x = F.interpolate(x, (h // 2, w // 2), mode = 'bilinear') + x = self.conv(x) + x = self.pixel_norm(x) + return x # mp activations # section 2.5 From 4fef364959f6eb4f4da0ea822c0fc036890653da Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 4 Feb 2024 09:51:08 -0800 Subject: [PATCH 16/21] have the resnet and attention block handle the residual with magnitude preserving add --- denoising_diffusion_pytorch/karras_unet.py | 46 ++++++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 6c01b53bc..04c13d3a5 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -236,31 +236,46 @@ def forward(self, x): # building block modules class Block(Module): - def __init__(self, dim, dim_out, groups = 8): + def 
__init__( + self, + dim, + dim_out, + mp_sum_t = 0.3 + ): super().__init__() - self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1, bias = False) + self.proj = WeightNormedConv2d(dim, dim_out, 3) self.act = MPSiLU() + self.mp_add = MPSum(t = mp_sum_t) def forward(self, x, scale = None): + res = x + x = self.proj(x) if exists(scale): x = x * (scale + 1) x = self.act(x) - return x + + return mp_add(x, res) class ResnetBlock(Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__( + self, + dim, + dim_out, + *, + time_emb_dim = None + ): super().__init__() self.mlp = nn.Sequential( MPSiLU(), - nn.Linear(time_emb_dim, dim_out * 2, bias = False) + nn.Linear(time_emb_dim, dim_out, bias = False) ) if exists(time_emb_dim) else None self.block1 = Block(dim, dim_out, groups = groups) self.block2 = Block(dim_out, dim_out, groups = groups) - self.res_conv = nn.Conv2d(dim, dim_out, 1, bias = False) if dim != dim_out else nn.Identity() + self.res_conv = WeightNormedConv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None): @@ -282,7 +297,8 @@ def __init__( heads = 4, dim_head = 32, num_mem_kv = 4, - flash = False + flash = False, + mp_sum_t = 0.3 ): super().__init__() self.heads = heads @@ -294,11 +310,13 @@ def __init__( self.attend = Attend(flash = flash, scale = dim_head ** 0.5) self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head)) - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1, bias = False) + self.to_qkv = WeightNormedConv2d(dim, hidden_dim * 3, 1) + self.to_out = WeightNormedConv2d(hidden_dim, dim, 1) + + self.mp_add = MPSum(t = mp_sum_t) def forward(self, x): - b, c, h, w = x.shape + res, b, c, h, w = x, *x.shape qkv = self.to_qkv(x).chunk(3, dim = 1) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv) @@ -313,7 +331,9 @@ def forward(self, x): out = self.attend(q, k, v) out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) - return self.to_out(out) + out = self.to_out(out) + + return self.mp_add(out, res) # model @@ -346,7 +366,7 @@ def __init__( input_channels = channels * (2 if self_condition else 1) init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + self.init_conv = WeightNormedConv2d(input_channels, init_dim, 7, concat_ones_to_input = True) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) @@ -420,7 +440,7 @@ def __init__( self.out_dim = default(out_dim, default_out_dim) self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) - self.final_conv = nn.Conv2d(dim, self.out_dim, 1, bias = False) + self.final_conv = WeightNormedConv2d(dim, self.out_dim, 1) @property def downsample_factor(self): From 1ac9ba55c9af35f35b87e7f5d06a4a137357cb9f Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 4 Feb 2024 11:05:08 -0800 Subject: [PATCH 17/21] use magnitude preserving cat for skip connections --- denoising_diffusion_pytorch/karras_unet.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 04c13d3a5..998cb7bcb 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -59,6 +59,8 @@ def __init__(self, dim): def forward(self, x): h, w = x.shape[-2:] + assert all([divisible_by(_, 2) for _ in (h, w)]) + x = 
F.interpolate(x, (h // 2, w // 2), mode = 'bilinear') x = self.conv(x) x = self.pixel_norm(x) @@ -355,7 +357,8 @@ def __init__( attn_dim_head = 32, attn_heads = 4, full_attn = None, # defaults to full attention only for inner most layer - flash_attn = False + flash_attn = False, + mp_cat_t = 0.5 ): super().__init__() @@ -411,6 +414,8 @@ def __init__( self.ups = ModuleList([]) num_resolutions = len(in_out) + self.mp_cat = MPCat(t = mp_cat_t, dim = 1) + for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): is_last = ind >= (num_resolutions - 1) @@ -475,10 +480,10 @@ def forward(self, x, time, x_self_cond = None): x = self.mid_block2(x, t) for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) + x = self.mp_cat(x, h.pop()) x = block1(x, t) - x = torch.cat((x, h.pop()), dim = 1) + x = self.mp_cat(x, h.pop()) x = block2(x, t) x = attn(x) + x From 9b7fc320dee9722826f554e43921251a8948d7a4 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 5 Feb 2024 09:57:59 -0800 Subject: [PATCH 18/21] complete the mp fourier embedding --- denoising_diffusion_pytorch/karras_unet.py | 35 +++++----------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 998cb7bcb..400b90cce 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -73,12 +73,6 @@ class MPSiLU(Module): def forward(self, x): return F.silu(x) / 0.596 -def sin(t): - return torch.sin(t) * (2 ** 0.5) - -def cos(t): - return torch.cos(t) * (2 ** 0.5) - # gain - layer scaling class Gain(Module): @@ -218,22 +212,18 @@ def forward(self, x): emb = torch.cat((sin(emb), cos(emb)), dim=-1) return emb -class RandomOrLearnedSinusoidalPosEmb(Module): - """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ - """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ +class MPFourierEmbedding(Module): - def __init__(self, dim, is_random = False): + def __init__(self, dim): super().__init__() assert divisible_by(dim, 2) half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) + self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False) def forward(self, x): x = rearrange(x, 'b -> b 1') freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) - fouriered = torch.cat((x, fouriered), dim = -1) - return fouriered + return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * (2 ** 0.5) # building block modules @@ -349,11 +339,8 @@ def __init__( channels = 3, self_condition = False, resnet_block_groups = 8, - learned_variance = False, - learned_sinusoidal_cond = False, - random_fourier_features = False, - learned_sinusoidal_dim = 16, - sinusoidal_pos_emb_theta = 10000, + sinusoidal_dim = 16, + fourier_theta = 10000, attn_dim_head = 32, attn_heads = 4, full_attn = None, # defaults to full attention only for inner most layer @@ -380,14 +367,8 @@ def __init__( time_dim = dim * 4 - self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features - - if self.random_or_learned_sinusoidal_cond: - sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) - fourier_dim = learned_sinusoidal_dim + 1 - else: - sinu_pos_emb = 
SinusoidalPosEmb(dim, theta = sinusoidal_pos_emb_theta) - fourier_dim = dim + sinu_pos_emb = MPFourierEmbedding(sinusoidal_dim) + fourier_dim = sinusoidal_dim self.time_mlp = nn.Sequential( sinu_pos_emb, From de35b120ebe9cf939d413b087c0dcf4c75476c83 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 5 Feb 2024 10:02:25 -0800 Subject: [PATCH 19/21] cleanup --- denoising_diffusion_pytorch/karras_unet.py | 58 +++++----------------- 1 file changed, 13 insertions(+), 45 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 400b90cce..77bc139f6 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -38,9 +38,6 @@ def cast_tuple(t, length = 1): def divisible_by(numer, denom): return (numer % denom) == 0 -def identity(t, *args, **kwargs): - return t - # in paper, they use eps 1e-4 for pixelnorm def l2norm(t, dim = -1, eps = 1e-12): @@ -48,7 +45,7 @@ def l2norm(t, dim = -1, eps = 1e-12): # small helper modules -def Upsample(dim, dim_out = None): +def Upsample(dim): return nn.Upsample(scale_factor = 2, mode = 'bilinear') class Downsample(Module): @@ -185,35 +182,9 @@ def forward(self, x): weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in) return F.linear(x, weight) -# norm - -class RMSNorm(Module): - def __init__(self, dim): - super().__init__() - self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - - def forward(self, x): - return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5) - -# sinusoidal positional embeds - -class SinusoidalPosEmb(Module): - def __init__(self, dim, theta = 10000): - super().__init__() - self.dim = dim - self.theta = theta - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(self.theta) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((sin(emb), cos(emb)), dim=-1) - return emb +# mp fourier embeds class MPFourierEmbedding(Module): - def __init__(self, dim): super().__init__() assert divisible_by(dim, 2) @@ -265,8 +236,8 @@ def __init__( nn.Linear(time_emb_dim, dim_out, bias = False) ) if exists(time_emb_dim) else None - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = WeightNormedConv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None): @@ -338,7 +309,6 @@ def __init__( dim_mults = (1, 2, 4, 8), channels = 3, self_condition = False, - resnet_block_groups = 8, sinusoidal_dim = 16, fourier_theta = 10000, attn_dim_head = 32, @@ -361,8 +331,6 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - # time embeddings time_dim = dim * 4 @@ -401,31 +369,31 @@ def __init__( is_last = ind >= (num_resolutions - 1) self.downs.append(ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), CosineSimAttention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + Downsample(dim_in, dim_out) if not is_last else 
WeightNormedConv2d(dim_in, dim_out, 3) ])) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) self.mid_attn = CosineSimAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): is_last = ind == (len(in_out) - 1) self.ups.append(ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), CosineSimAttention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1, bias = False) + Upsample(dim_out, dim_in) if not is_last else WeightNormedConv2d(dim_out, dim_in, 3) ])) default_out_dim = channels * (1 if not learned_variance else 2) self.out_dim = default(out_dim, default_out_dim) - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim) self.final_conv = WeightNormedConv2d(dim, self.out_dim, 1) @property From e82a4a03ae68eaa797a942008fdcc38404760e33 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 6 Feb 2024 11:45:10 -0800 Subject: [PATCH 20/21] almost done --- denoising_diffusion_pytorch/karras_unet.py | 441 ++++++++++++++------- 1 file changed, 295 insertions(+), 146 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 77bc139f6..5a764e3e0 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -1,5 +1,5 @@ import math -from math import sqrt +from math import sqrt, ceil from random import random from functools import partial @@ -24,6 +24,12 @@ def default(val, d): return val return d() if callable(d) else d +def xnor(x, y): + return not (x ^ y) + +def prepend(arr, el): + arr.insert(0, el) + def pack_one(t, pattern): return pack([t], pattern) @@ -43,26 +49,6 @@ def divisible_by(numer, denom): def l2norm(t, dim = -1, eps = 1e-12): return F.normalize(t, dim = dim, eps = eps) -# small helper modules - -def Upsample(dim): - return nn.Upsample(scale_factor = 2, mode = 'bilinear') - -class Downsample(Module): - def __init__(self, dim): - super().__init__() - self.conv = WeightNormedConv2d(dim, dim, 1) - self.pixel_norm = PixelNorm(dim = 1) - - def forward(self, x): - h, w = x.shape[-2:] - assert all([divisible_by(_, 2) for _ in (h, w)]) - - x = F.interpolate(x, (h // 2, w // 2), mode = 'bilinear') - x = self.conv(x) - x = self.pixel_norm(x) - return x - # mp activations # section 2.5 @@ -105,7 +91,7 @@ def forward(self, a, b): # empirically, they found t=0.3 for encoder / decoder / attention residuals # and for embedding, t=0.5 -class MPSum(Module): +class MPAdd(Module): def __init__(self, t): super().__init__() self.t = t @@ -133,7 +119,7 @@ def forward(self, x): # forced weight normed conv2d and linear # algorithm 1 in paper -class WeightNormedConv2d(Module): +class Conv2d(Module): def __init__( self, 
dim_in, @@ -165,7 +151,7 @@ def forward(self, x): return F.conv2d(x, weight, padding='same') -class WeightNormedLinear(Module): +class Linear(Module): def __init__(self, dim_in, dim_out, eps = 1e-4): super().__init__() weight = torch.randn(dim_out, dim_in) @@ -194,71 +180,169 @@ def __init__(self, dim): def forward(self, x): x = rearrange(x, 'b -> b 1') freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * (2 ** 0.5) + return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2) # building block modules -class Block(Module): +class Encoder(Module): def __init__( self, dim, - dim_out, - mp_sum_t = 0.3 + dim_out = None, + *, + emb_dim = None, + dropout = 0.1, + mp_add_t = 0.3, + has_attn = False, + attn_dim_head = 64, + downsample = False ): super().__init__() - self.proj = WeightNormedConv2d(dim, dim_out, 3) - self.act = MPSiLU() - self.mp_add = MPSum(t = mp_sum_t) + dim_out = default(dim_out, dim) + + self.downsample = downsample + self.downsample_conv = None + + if downsample: + self.downsample_conv = Conv2d(dim, dim_out, 1) + + self.pixel_norm = PixelNorm(dim = 1) + + self.to_emb = None + if exists(emb_dim): + self.to_emb = nn.Sequential( + Linear(emb_dim, dim_out), + Gain() + ) + + self.block1 = nn.Sequential( + MPSiLU(), + Conv2d(dim, dim_out, 3) + ) + + self.block2 = nn.Sequential( + MPSiLU(), + nn.Dropout(dropout), + Conv2d(dim_out, dim_out, 3) + ) + + self.res_mp_add = MPAdd(t = mp_add_t) + + self.attn = None + if has_attn: + self.attn = Attention( + dim = dim, + heads = ceil(dim / attn_dim_head), + dim_head = attn_dim_head + ) + + def forward( + self, + x, + emb = None + ): + if self.downsample: + h, w = x.shape[-2:] + x = F.interpolate(x, (h // 2, w // 2), mode = 'bilinear') + x = self.downsample_conv(x) + + x = self.pixel_norm(x) + + res = x.clone() - def forward(self, x, scale = None): - res = x + x = self.block1(x) - x = self.proj(x) + if exists(emb): + scale = self.to_emb(emb) + 1 + x = x * scale - if exists(scale): - x = x * (scale + 1) + x = self.block2(x) - x = self.act(x) + x = self.res_mp_add(x, res) - return mp_add(x, res) + if exists(self.attn): + x = self.attn(x) -class ResnetBlock(Module): + return x + +class Decoder(Module): def __init__( self, dim, - dim_out, + dim_out = None, *, - time_emb_dim = None + emb_dim = None, + dropout = 0.1, + mp_add_t = 0.3, + has_attn = False, + attn_dim_head = 64, + upsample = False ): super().__init__() - self.mlp = nn.Sequential( + dim_out = default(dim_out, dim) + + self.upsample = upsample + + self.to_emb = None + if exists(emb_dim): + self.to_emb = nn.Sequential( + Linear(emb_dim, dim_out), + Gain() + ) + + self.block1 = nn.Sequential( MPSiLU(), - nn.Linear(time_emb_dim, dim_out, bias = False) - ) if exists(time_emb_dim) else None + Conv2d(dim, dim_out, 3) + ) - self.block1 = Block(dim, dim_out) - self.block2 = Block(dim_out, dim_out) - self.res_conv = WeightNormedConv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + self.block2 = nn.Sequential( + MPSiLU(), + nn.Dropout(dropout), + Conv2d(dim_out, dim_out, 3) + ) - def forward(self, x, time_emb = None): + self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - scale = None - if exists(self.mlp) and exists(time_emb): - scale = self.mlp(time_emb) - scale = rearrange(scale, 'b c -> b c 1 1') + self.res_mp_add = MPAdd(t = mp_add_t) - h = self.block1(x, scale = scale) + self.attn = None + if has_attn: + self.attn = Attention( + dim = dim, + heads = ceil(dim / attn_dim_head), + 
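+                # one attention head per attn_dim_head channels, so head width stays fixed as dim grows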
dim_head = attn_dim_head + ) - h = self.block2(h) + def forward( + self, + x, + emb = None + ): + res = self.res_conv(x) - return h + self.res_conv(x) + x = self.block1(x) -class CosineSimAttention(Module): + if exists(emb): + scale = self.to_emb(emb) + 1 + x = x * rearrange(scale, 'b c -> b c 1 1') + + x = self.block2(x) + + x = self.res_mp_add(x, res) + + if exists(self.attn): + x = self.attn(x) + + return x + +# attention + +class Attention(Module): def __init__( self, dim, heads = 4, - dim_head = 32, + dim_head = 64, num_mem_kv = 4, flash = False, mp_sum_t = 0.3 @@ -270,13 +354,13 @@ def __init__( self.pixel_norm = PixelNorm(dim = 1) # equation (34) - they used cosine sim of queries and keys with a fixed scale of sqrt(Nc) - self.attend = Attend(flash = flash, scale = dim_head ** 0.5) + self.attend = Attend(flash = flash, scale = sqrt(dim_head)) self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head)) - self.to_qkv = WeightNormedConv2d(dim, hidden_dim * 3, 1) - self.to_out = WeightNormedConv2d(hidden_dim, dim, 1) + self.to_qkv = Conv2d(dim, hidden_dim * 3, 1) + self.to_out = Conv2d(hidden_dim, dim, 1) - self.mp_add = MPSum(t = mp_sum_t) + self.mp_add = MPAdd(t = mp_sum_t) def forward(self, x): res, b, c, h, w = x, *x.shape @@ -298,147 +382,212 @@ def forward(self, x): return self.mp_add(out, res) -# model +# unet proposed by karras +# bias-less, no group-norms, with magnitude preserving operations class KarrasUnet(Module): + """ + going by figure 21. config G + """ + def __init__( self, - dim, - init_dim = None, - out_dim = None, - dim_mults = (1, 2, 4, 8), - channels = 3, - self_condition = False, - sinusoidal_dim = 16, - fourier_theta = 10000, - attn_dim_head = 32, - attn_heads = 4, - full_attn = None, # defaults to full attention only for inner most layer - flash_attn = False, - mp_cat_t = 0.5 + *, + image_size, + dim = 192, + dim_max = 768, # channels will double every downsample and cap out to this value + num_classes = None, # in paper, they do 1000 classes for a popular benchmark + channels = 4, # 4 channels in paper for some reason, must be alpha channel? 
+ num_downsamples = 3, + num_blocks_per_stage = 4, + attn_res = (16, 8), + fourier_dim = 16, + attn_dim_head = 64, + attn_flash = False, + mp_cat_t = 0.5, + mp_add_emb_t = 0.5, + attn_mp_sum_t = 0.3, + resnet_mp_sum_t = 0.3, + dropout = 0.1, + self_condition = False ): super().__init__() + self.self_condition = self_condition + # determine dimensions self.channels = channels - self.self_condition = self_condition + self.image_size = image_size input_channels = channels * (2 if self_condition else 1) - init_dim = default(init_dim, dim) - self.init_conv = WeightNormedConv2d(input_channels, init_dim, 7, concat_ones_to_input = True) + # input and output blocks - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) + self.input_block = Conv2d(input_channels, dim, 3, concat_ones_to_input = True) - # time embeddings + self.output_block = nn.Sequential( + Conv2d(dim, channels, 3), + Gain() + ) - time_dim = dim * 4 + # time embedding - sinu_pos_emb = MPFourierEmbedding(sinusoidal_dim) - fourier_dim = sinusoidal_dim + emb_dim = dim * 4 - self.time_mlp = nn.Sequential( - sinu_pos_emb, - nn.Linear(fourier_dim, time_dim, bias = False), - nn.GELU(), - nn.Linear(time_dim, time_dim, bias = False) + self.to_time_emb = nn.Sequential( + MPFourierEmbedding(fourier_dim), + Linear(fourier_dim, emb_dim) ) + # class embedding + + self.needs_class_labels = exists(num_classes) + self.num_classes = num_classes + + self.to_class_emb = Linear(num_classes, 4 * dim) + self.add_class_emb = MPAdd(t = mp_add_emb_t) + + # final embedding activations + + self.emb_activation = MPSiLU() + + # number of downsamples + + self.num_downsamples = num_downsamples + # attention - if not full_attn: - full_attn = (*((False,) * (len(dim_mults) - 1)), True) + attn_kwargs = dict( + flash = attn_flash, + dim_head = attn_dim_head, + mp_sum_t = attn_mp_sum_t + ) - num_stages = len(dim_mults) - full_attn = cast_tuple(full_attn, num_stages) - attn_heads = cast_tuple(attn_heads, num_stages) - attn_dim_head = cast_tuple(attn_dim_head, num_stages) + attn_res = cast_tuple(attn_res) - assert len(full_attn) == len(dim_mults) + # resnet block - # layers + block_kwargs = dict( + dropout = dropout, + emb_dim = emb_dim + ) + # unet encoder and decoders + + stages = num_downsamples + 1 self.downs = ModuleList([]) self.ups = ModuleList([]) - num_resolutions = len(in_out) - self.mp_cat = MPCat(t = mp_cat_t, dim = 1) + curr_dim = dim + curr_res = image_size + + stage_dims = [dim] + + skip_dims = [dim] + self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1) + + # take care of skip connection for initial input block + + prepend(self.ups, Decoder(dim + skip_dims.pop(), dim, **block_kwargs)) + + # stages - for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): - is_last = ind >= (num_resolutions - 1) + return + + for stage in range(stages): + is_last = stage == (stages - 1) + downsample = not is_last + has_attn = curr_res in attn_res + + if downsample: + dim_out = min(curr_dim * 2, dim_max) self.downs.append(ModuleList([ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), - CosineSimAttention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), - Downsample(dim_in, dim_out) if not is_last else WeightNormedConv2d(dim_in, dim_out, 3) + Attention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), + 
Downsample(dim_in, dim_out) if not is_last else Conv2d(dim_in, dim_out, 3) ])) - mid_dim = dims[-1] - self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) - self.mid_attn = CosineSimAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) - self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) + if downsample: + curr_res //= 2 + curr_dim = dim_out + stage_dims.append(dim_out) - for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): - is_last = ind == (len(in_out) - 1) + for stage in range(stages): + is_first = stage == 0 self.ups.append(ModuleList([ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - CosineSimAttention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), - Upsample(dim_out, dim_in) if not is_last else WeightNormedConv2d(dim_out, dim_in, 3) + Attention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), + Upsample(dim_out, dim_in) if not is_last else Conv2d(dim_out, dim_in, 3) ])) - default_out_dim = channels * (1 if not learned_variance else 2) - self.out_dim = default(out_dim, default_out_dim) - - self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim) - self.final_conv = WeightNormedConv2d(dim, self.out_dim, 1) - @property def downsample_factor(self): - return 2 ** (len(self.downs) - 1) + return 2 ** self.num_downsamples + + def forward( + self, + x, + time, + self_cond = None, + class_labels = None + ): + # validate image shape + + assert x.shape[1:] == (self.channels, self.image_size, self.image_size) - def forward(self, x, time, x_self_cond = None): - assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet' + # self conditioning if self.self_condition: - x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) - x = torch.cat((x_self_cond, x), dim = 1) + self_cond = default(self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((self_cond, x), dim = 1) + else: + assert not exists(self_cond) + + # time condition + + time_emb = self.to_time_emb(time) + + # class condition + + assert xnor(exists(class_labels), self.needs_class_labels) + + if self.needs_class_labels: + if class_labels.dtype in (torch.int, torch.long): + class_labels = F.one_hot(class_labels, self.num_classes) + + assert class_labels.shape[-1] == self.num_classes + class_labels = class_labels.float() * sqrt(self.num_classes) + + class_emb = self.to_class_emb(class_labels) + + time_emb = self.add_class_emb(time_emb, class_emb) - x = self.init_conv(x) - r = x.clone() + # final mp-silu for embedding - t = self.time_mlp(time) + emb = self.emb_activation(time_emb) - h = [] + # skip connections - for block1, block2, attn, downsample in self.downs: - x = block1(x, t) - h.append(x) + skips = [] - x = block2(x, t) - x = attn(x) + x - h.append(x) + # input block - x = downsample(x) + x = self.input_block(x) - x = self.mid_block1(x, t) - x = self.mid_attn(x) + x - x = self.mid_block2(x, t) + skips.append(x) - for block1, block2, attn, upsample in self.ups: - x = self.mp_cat(x, h.pop()) - x = block1(x, t) + # down - x = self.mp_cat(x, h.pop()) - x = block2(x, t) - x = attn(x) + x + # up - x = upsample(x) + for decoder in self.ups: + x = self.skip_mp_cat(x, 
skips.pop()) + x = decoder(x, emb = emb) - x = torch.cat((x, r), dim = 1) + # output block - x = self.final_res_block(x, t) - return self.final_conv(x) + return self.output_block(x) From 42a9e793ad6623bb93d3b4dd4b97760ca51084aa Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 6 Feb 2024 14:28:46 -0800 Subject: [PATCH 21/21] complete the karras unet --- denoising_diffusion_pytorch/karras_unet.py | 155 ++++++++++++++------- denoising_diffusion_pytorch/version.py | 2 +- 2 files changed, 104 insertions(+), 53 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index 5a764e3e0..96c2a6c9e 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -27,6 +27,9 @@ def default(val, d): def xnor(x, y): return not (x ^ y) +def append(arr, el): + arr.append(el) + def prepend(arr, el): arr.insert(0, el) @@ -195,6 +198,8 @@ def __init__( mp_add_t = 0.3, has_attn = False, attn_dim_head = 64, + attn_res_mp_add_t = 0.3, + attn_flash = False, downsample = False ): super().__init__() @@ -203,8 +208,10 @@ def __init__( self.downsample = downsample self.downsample_conv = None + curr_dim = dim if downsample: - self.downsample_conv = Conv2d(dim, dim_out, 1) + self.downsample_conv = Conv2d(curr_dim, dim_out, 1) + curr_dim = dim_out self.pixel_norm = PixelNorm(dim = 1) @@ -217,7 +224,7 @@ def __init__( self.block1 = nn.Sequential( MPSiLU(), - Conv2d(dim, dim_out, 3) + Conv2d(curr_dim, dim_out, 3) ) self.block2 = nn.Sequential( @@ -231,9 +238,11 @@ def __init__( self.attn = None if has_attn: self.attn = Attention( - dim = dim, - heads = ceil(dim / attn_dim_head), - dim_head = attn_dim_head + dim = dim_out, + heads = ceil(dim_out / attn_dim_head), + dim_head = attn_dim_head, + mp_add_t = attn_res_mp_add_t, + flash = attn_flash ) def forward( @@ -254,7 +263,7 @@ def forward( if exists(emb): scale = self.to_emb(emb) + 1 - x = x * scale + x = x * rearrange(scale, 'b c -> b c 1 1') x = self.block2(x) @@ -276,12 +285,15 @@ def __init__( mp_add_t = 0.3, has_attn = False, attn_dim_head = 64, + attn_res_mp_add_t = 0.3, + attn_flash = False, upsample = False ): super().__init__() dim_out = default(dim_out, dim) self.upsample = upsample + self.needs_skip = not upsample self.to_emb = None if exists(emb_dim): @@ -308,9 +320,11 @@ def __init__( self.attn = None if has_attn: self.attn = Attention( - dim = dim, - heads = ceil(dim / attn_dim_head), - dim_head = attn_dim_head + dim = dim_out, + heads = ceil(dim_out / attn_dim_head), + dim_head = attn_dim_head, + mp_add_t = attn_res_mp_add_t, + flash = attn_flash ) def forward( @@ -318,6 +332,10 @@ def forward( x, emb = None ): + if self.upsample: + h, w = x.shape[-2:] + x = F.interpolate(x, (h * 2, w * 2), mode = 'bilinear') + res = self.res_conv(x) x = self.block1(x) @@ -345,7 +363,7 @@ def __init__( dim_head = 64, num_mem_kv = 4, flash = False, - mp_sum_t = 0.3 + mp_add_t = 0.3 ): super().__init__() self.heads = heads @@ -360,7 +378,7 @@ def __init__( self.to_qkv = Conv2d(dim, hidden_dim * 3, 1) self.to_out = Conv2d(hidden_dim, dim, 1) - self.mp_add = MPAdd(t = mp_sum_t) + self.mp_add = MPAdd(t = mp_add_t) def forward(self, x): res, b, c, h, w = x, *x.shape @@ -406,8 +424,8 @@ def __init__( attn_flash = False, mp_cat_t = 0.5, mp_add_emb_t = 0.5, - attn_mp_sum_t = 0.3, - resnet_mp_sum_t = 0.3, + attn_res_mp_add_t = 0.3, + resnet_mp_add_t = 0.3, dropout = 0.1, self_condition = False ): @@ -457,72 +475,73 @@ def __init__( # attention - attn_kwargs = dict( - flash = 
attn_flash, - dim_head = attn_dim_head, - mp_sum_t = attn_mp_sum_t - ) - - attn_res = cast_tuple(attn_res) + attn_res = set(cast_tuple(attn_res)) # resnet block block_kwargs = dict( dropout = dropout, - emb_dim = emb_dim + emb_dim = emb_dim, + attn_dim_head = attn_dim_head, + attn_res_mp_add_t = attn_res_mp_add_t, + attn_flash = attn_flash ) # unet encoder and decoders - stages = num_downsamples + 1 self.downs = ModuleList([]) self.ups = ModuleList([]) curr_dim = dim curr_res = image_size - stage_dims = [dim] - - skip_dims = [dim] self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1) - # take care of skip connection for initial input block + # take care of skip connection for initial input block and first three encoder blocks + + prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs)) + + assert num_blocks_per_stage >= 1 + + for _ in range(num_blocks_per_stage): + enc = Encoder(curr_dim, curr_dim, **block_kwargs) + dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs) - prepend(self.ups, Decoder(dim + skip_dims.pop(), dim, **block_kwargs)) + append(self.downs, enc) + prepend(self.ups, dec) # stages - return + for _ in range(self.num_downsamples): + dim_out = min(dim_max, curr_dim * 2) + upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs) - for stage in range(stages): - is_last = stage == (stages - 1) - downsample = not is_last + curr_res //= 2 has_attn = curr_res in attn_res - if downsample: - dim_out = min(curr_dim * 2, dim_max) + downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs) - self.downs.append(ModuleList([ - ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), - ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), - Attention(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), - Downsample(dim_in, dim_out) if not is_last else Conv2d(dim_in, dim_out, 3) - ])) + append(self.downs, downsample) + prepend(self.ups, upsample) + prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)) - if downsample: - curr_res //= 2 - curr_dim = dim_out - stage_dims.append(dim_out) + for _ in range(num_blocks_per_stage): + enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs) + dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs) - for stage in range(stages): - is_first = stage == 0 + append(self.downs, enc) + prepend(self.ups, dec) - self.ups.append(ModuleList([ - ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Attention(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads, flash = flash_attn), - Upsample(dim_out, dim_in) if not is_last else Conv2d(dim_out, dim_in, 3) - ])) + curr_dim = dim_out + + # take care of the two middle decoders + + mid_has_attn = curr_res in attn_res + + self.mids = ModuleList([ + Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs), + Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs), + ]) @property def downsample_factor(self): @@ -582,12 +601,44 @@ def forward( # down + for encoder in self.downs: + x = encoder(x, emb = emb) + skips.append(x) + + # mid + + for decoder in self.mids: + x = decoder(x, emb = emb) + # up for decoder in self.ups: - x = self.skip_mp_cat(x, skips.pop()) + if decoder.needs_skip: + skip = skips.pop() + x = self.skip_mp_cat(x, skip) + x = decoder(x, emb = emb) # output block return self.output_block(x) + +# example + +if __name__ == 
'__main__': + unet = KarrasUnet( + image_size = 64, + dim = 192, + dim_max = 768, + num_classes = 1000, + ) + + images = torch.randn(2, 4, 64, 64) + + denoised_images = unet( + images, + time = torch.ones(2,), + class_labels = torch.randint(0, 1000, (2,)) + ) + + assert denoised_images.shape == images.shape diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py index 4c3e0583c..52af183e5 100644 --- a/denoising_diffusion_pytorch/version.py +++ b/denoising_diffusion_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.9.6' +__version__ = '1.10.0'
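The hunks above lean on a set of magnitude-preserving primitives — `MPSiLU`, `PixelNorm`, `Gain`, `MPAdd`, `MPCat` — whose class bodies fall outside the visible context lines. Below is a minimal sketch of how such operators can be written, following the formulas in Karras et al. 2023; the exact constants, which branch receives the weight `t`, and the parameter initializations are assumptions here, not necessarily what `karras_unet.py` commits.

```python
import math
import torch
from torch import nn
import torch.nn.functional as F

class MPSiLU(nn.Module):
    # SiLU rescaled so a unit-variance input keeps roughly unit output magnitude
    def forward(self, x):
        return F.silu(x) / 0.596

class PixelNorm(nn.Module):
    # normalize along `dim`, then rescale by sqrt(dim size) to preserve magnitude
    def __init__(self, dim = 1, eps = 1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return F.normalize(x, dim = self.dim, eps = self.eps) * math.sqrt(x.shape[self.dim])

class Gain(nn.Module):
    # learned scalar gain, initialized to zero (used after the output conv and the embeddings)
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

class MPAdd(nn.Module):
    # magnitude-preserving sum: blend two branches by t, then divide by the expected norm of the blend
    def __init__(self, t = 0.3):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        t = self.t
        return (x * (1. - t) + res * t) / math.sqrt((1. - t) ** 2 + t ** 2)

class MPCat(nn.Module):
    # magnitude-preserving concatenation along `dim`, weighting the two inputs by (1 - t) and t
    def __init__(self, t = 0.5, dim = 1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        dim, t = self.dim, self.t
        Na, Nb = a.shape[dim], b.shape[dim]

        a = a * (1. - t) / math.sqrt(Na)
        b = b * t / math.sqrt(Nb)

        scale = math.sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))
        return torch.cat((a, b), dim = dim) * scale
```

With `t = 0.3` for residual additions and `t = 0.5` for skip concatenations (the `mp_add_t` / `mp_cat_t` defaults in the hunks above), each branch contributes a fixed, known fraction of the output magnitude regardless of network depth.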
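The `WeightNormedLinear` → `Linear` rename (and the matching `Conv2d`) likewise hides the layer bodies. In EDM2-style networks these are forced weight-normalized layers: the stored weight is re-normalized to unit norm during training, and the forward pass divides by the square root of the fan-in so the layer is magnitude preserving. A rough sketch under those assumptions — the `eps` handling, the training-time copy, and the normalization axes are guesses, not the committed code:

```python
import math
import torch
from torch import nn
import torch.nn.functional as F

def normalize_weight(weight, eps = 1e-4):
    # unit-normalize each output unit's weight vector (all dims except dim 0)
    flat = weight.flatten(1)
    norm = flat.norm(dim = 1, keepdim = True).clamp(min = eps)
    return (flat / norm).reshape_as(weight)

class Linear(nn.Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        if self.training:
            # forced weight normalization: write the normalized weight back into the parameter
            with torch.no_grad():
                self.weight.copy_(normalize_weight(self.weight, eps = self.eps))

        weight = normalize_weight(self.weight, eps = self.eps) / math.sqrt(self.fan_in)
        return F.linear(x, weight)

class Conv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, eps = 1e-4, concat_ones_to_input = False):
        super().__init__()
        # the extra constant channel mirrors the `concat_ones_to_input` flag used by the input block
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = weight[0].numel()
        self.concat_ones_to_input = concat_ones_to_input

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize_weight(self.weight, eps = self.eps))

        weight = normalize_weight(self.weight, eps = self.eps) / math.sqrt(self.fan_in)

        if self.concat_ones_to_input:
            # prepend a channel of ones so the convolution can learn a bias-like offset without a bias term
            x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.)

        return F.conv2d(x, weight, padding = 'same')
```

In the paper, this forced normalization is what keeps weight magnitudes from drifting during training, which is why the network can drop biases and group norms entirely, as these patches do.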