use layernorm for attention
lucidrains committed Apr 7, 2021
1 parent f4cd064 commit 49c4b27
Showing 2 changed files with 20 additions and 9 deletions.
27 changes: 19 additions & 8 deletions lightweight_gan/lightweight_gan.py
@@ -156,15 +156,26 @@ def forward(self, x):
         fn = self.fn if random() < self.prob else self.fn_else
         return fn(x)
 
-class LayerScale(nn.Module):
+class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
-        scale = torch.zeros(1, dim, 1, 1).fill_(1e-3)
-        self.g = nn.Parameter(scale)
+        self.norm = ChanNorm(dim)
 
     def forward(self, x):
-        return self.g * self.fn(x)
+        return self.fn(self.norm(x))
+
+class ChanNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (std + self.eps) * self.g + self.b
 
 class Residual(nn.Module):
     def __init__(self, fn):
@@ -483,7 +494,7 @@ def __init__(
 
             attn = None
             if image_width in attn_res_layers:
-                attn = LayerScale(chan_in, LinearAttention(chan_in))
+                attn = PreNorm(chan_in, LinearAttention(chan_in))
 
             sle = None
             if res in self.sle_map:
@@ -629,7 +640,7 @@ def __init__(
 
             attn = None
             if image_width in attn_res_layers:
-                attn = LayerScale(chan_in, LinearAttention(chan_in))
+                attn = PreNorm(chan_in, LinearAttention(chan_in))
 
             self.residual_layers.append(nn.ModuleList([
                 SumBranches([
@@ -667,7 +678,7 @@ def __init__(
 
         self.to_shape_disc_out = nn.Sequential(
             nn.Conv2d(init_channel, 64, 3, padding = 1),
-            Residual(LayerScale(64, LinearAttention(64))),
+            Residual(PreNorm(64, LinearAttention(64))),
             SumBranches([
                 nn.Sequential(
                     Blur(),
@@ -683,7 +694,7 @@ def __init__(
                     nn.LeakyReLU(0.1),
                 )
             ]),
-            Residual(LayerScale(32, LinearAttention(32))),
+            Residual(PreNorm(32, LinearAttention(32))),
             nn.AdaptiveAvgPool2d((4, 4)),
             nn.Conv2d(32, 1, 4)
         )
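Note: a minimal standalone sketch of what this commit changes. Attention blocks are now wrapped in PreNorm, which applies the new channel-wise ChanNorm to the input before the wrapped module, instead of LayerScale, which only rescaled the module's output. ChanNorm and PreNorm below are copied from the diff above; the stand-in attention module, channel count, and input shape are illustrative assumptions, not code from the repository.

import torch
import torch.nn as nn

class ChanNorm(nn.Module):
    # copied from the diff above: layernorm over the channel dim of a (B, C, H, W) tensor
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    # copied from the diff above: normalize the input, then run the wrapped module
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

# illustrative stand-in for LinearAttention (hypothetical, not from the repo)
chan_in = 64
fake_attn = nn.Conv2d(chan_in, chan_in, 3, padding = 1)

attn = PreNorm(chan_in, fake_attn)   # was: LayerScale(chan_in, LinearAttention(chan_in))
x = torch.randn(2, chan_in, 32, 32)
out = attn(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])

# sanity check: with g = 1 and b = 0, ChanNorm gives ~zero mean, ~unit variance per pixel over channels
y = ChanNorm(chan_in)(x)
print(y.mean(dim = 1).abs().max())              # close to 0
print(y.var(dim = 1, unbiased = False).mean())  # close to 1

In the diff itself the residual connection still comes from the existing Residual wrapper, e.g. Residual(PreNorm(64, LinearAttention(64))) in the discriminator head, so (assuming Residual adds its input to the wrapped module's output) these layers now follow the usual pre-norm attention pattern x + attn(norm(x)).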
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.19.0'
+__version__ = '0.20.0'
