Skip to content

Commit

Permalink
upgrade to best downsample type
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 25, 2022
1 parent 33b454f commit 7cebbbe
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,14 @@ If you want the current state of the art GAN, you can find it at https://github.
}
```

```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```

*What I cannot create, I do not understand* - Richard Feynman
22 changes: 16 additions & 6 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from tqdm import tqdm
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from adabelief_pytorch import AdaBelief

Expand Down Expand Up @@ -456,6 +457,15 @@ def init_conv_(self, conv):
def forward(self, x):
    # Pure delegation: run the input through the module's composed `self.net`
    # (built in __init__, not visible in this fragment) and return its output.
    return self.net(x)

def SPConvDownsample(dim, dim_out = None):
    """Build a 2x spatial downsampling module via space-to-depth + pointwise conv.

    https://arxiv.org/abs/2208.03641 shows this is the most optimal way to
    downsample. Named SP-conv in the paper, but basically a pixel unshuffle:
    each 2x2 spatial patch is folded into the channel dimension (quadrupling
    the channels), then a 1x1 convolution projects down to `dim_out` channels.

    Args:
        dim:     number of input channels.
        dim_out: number of output channels; defaults to `dim` when omitted.

    Returns:
        nn.Sequential mapping (b, dim, h, w) -> (b, dim_out, h // 2, w // 2).
    """
    if dim_out is None:
        dim_out = dim

    space_to_depth = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)
    project = nn.Conv2d(dim * 4, dim_out, 1)

    return nn.Sequential(space_to_depth, project)

# squeeze excitation classes

# global context network
Expand Down Expand Up @@ -611,9 +621,9 @@ def __init__(

layer = nn.ModuleList([
nn.Sequential(
PixelShuffleUpsample(chan_in, chan_out),
PixelShuffleUpsample(chan_in),
Blur(),
Conv2dSame(chan_out, chan_out * 2, 4),
Conv2dSame(chan_in, chan_out * 2, 4),
Noise(),
norm_class(chan_out * 2),
nn.GLU(dim = 1)
Expand Down Expand Up @@ -667,8 +677,8 @@ def __init__(
last_layer = ind == (num_upsamples - 1)
chan_out = chans if not last_layer else final_chan * 2
layer = nn.Sequential(
PixelShuffleUpsample(chans, chan_out),
nn.Conv2d(chan_out, chan_out, 3, padding = 1),
PixelShuffleUpsample(chans),
nn.Conv2d(chans, chan_out, 3, padding = 1),
nn.GLU(dim = 1)
)
self.layers.append(layer)
Expand Down Expand Up @@ -743,7 +753,7 @@ def __init__(
SumBranches([
nn.Sequential(
Blur(),
nn.Conv2d(chan_in, chan_out, 4, stride = 2, padding = 1),
SPConvDownsample(chan_in, chan_out),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_out, chan_out, 3, padding = 1),
nn.LeakyReLU(0.1)
Expand Down Expand Up @@ -779,7 +789,7 @@ def __init__(
SumBranches([
nn.Sequential(
Blur(),
nn.Conv2d(64, 32, 4, stride = 2, padding = 1),
SPConvDownsample(64, 32),
nn.LeakyReLU(0.1),
nn.Conv2d(32, 32, 3, padding = 1),
nn.LeakyReLU(0.1)
Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.0'
__version__ = '1.1.0'

0 comments on commit 7cebbbe

Please sign in to comment.