diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py
index ffb8b8e..bc1d107 100644
--- a/lightweight_gan/lightweight_gan.py
+++ b/lightweight_gan/lightweight_gan.py
@@ -156,15 +156,34 @@ def forward(self, x):
         fn = self.fn if random() < self.prob else self.fn_else
         return fn(x)
 
-class LayerScale(nn.Module):
+class PreNorm(nn.Module):
+    """Apply ChanNorm to the input before passing it to the wrapped module `fn`."""
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
-        scale = torch.zeros(1, dim, 1, 1).fill_(1e-3)
-        self.g = nn.Parameter(scale)
+        self.norm = ChanNorm(dim)
 
     def forward(self, x):
-        return self.g * self.fn(x)
+        return self.fn(self.norm(x))
+
+class ChanNorm(nn.Module):
+    """Normalize over dim 1 with a learned per-channel gain `g` and bias `b`.
+
+    Parameters are shaped (1, dim, 1, 1), so inputs are expected to broadcast
+    against that (presumably (batch, dim, height, width) feature maps).
+    """
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        # eps goes inside the sqrt: sqrt(var) has an infinite derivative at
+        # var == 0, which turns gradients into NaN wherever all channels are equal
+        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
 class Residual(nn.Module):
     def __init__(self, fn):
@@ -483,7 +502,7 @@ def __init__(
 
             attn = None
             if image_width in attn_res_layers:
-                attn = LayerScale(chan_in, LinearAttention(chan_in))
+                attn = PreNorm(chan_in, LinearAttention(chan_in))
 
             sle = None
             if res in self.sle_map:
@@ -629,7 +648,7 @@ def __init__(
 
             attn = None
             if image_width in attn_res_layers:
-                attn = LayerScale(chan_in, LinearAttention(chan_in))
+                attn = PreNorm(chan_in, LinearAttention(chan_in))
 
             self.residual_layers.append(nn.ModuleList([
                 SumBranches([
@@ -667,7 +686,7 @@ def __init__(
 
         self.to_shape_disc_out = nn.Sequential(
             nn.Conv2d(init_channel, 64, 3, padding = 1),
-            Residual(LayerScale(64, LinearAttention(64))),
+            Residual(PreNorm(64, LinearAttention(64))),
             SumBranches([
                 nn.Sequential(
                     Blur(),
@@ -683,7 +702,7 @@ def __init__(
                     nn.LeakyReLU(0.1),
                 )
             ]),
-            Residual(LayerScale(32, LinearAttention(32))),
+            Residual(PreNorm(32, LinearAttention(32))),
             nn.AdaptiveAvgPool2d((4, 4)),
             nn.Conv2d(32, 1, 4)
         )
diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py
index 482e4a1..2f15b8c 100644
--- a/lightweight_gan/version.py
+++ b/lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.19.0'
+__version__ = '0.20.0'