Skip to content

Commit

Permalink
release 0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 16, 2020
1 parent 2c98d99 commit 3cd5d84
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 26 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
## 'Lightweight' GAN (wip)
## 'Lightweight' GAN

Implementation of an extremely <a href="https://openreview.net/forum?id=1Fqg133qRaI">'lightweight' GAN</a> proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".

## Install

```bash
$ pip install lightweight-gan
```

## Usage

```bash
$ lightweight_gan --data ./path/to/images --image-size 512
```

## Citations

```bibtex
Expand Down
139 changes: 115 additions & 24 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import json
import multiprocessing
from random import random
import math
from math import log2, floor
from functools import partial
from contextlib import contextmanager
from pathlib import Path
from shutil import rmtree

import torch
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad

from PIL import Image
import torchvision
from torchvision import transforms

from lightweight_gan.diff_augment import DiffAugment
from lightweight_gan.version import __version__

from tqdm import tqdm
from einops import rearrange
from pytorch_fid import fid_score

Expand All @@ -34,6 +42,10 @@
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' count as existing)."""
    is_missing = val is None
    return not is_missing

@contextmanager
def null_context():
    """Context manager that does nothing on entry or exit and binds None to the `as` target."""
    yield None

def default(val, d):
    """Return `val`, falling back to `d` only when `val` is None.

    Falsy-but-present values (0, '', [], False) are returned unchanged.
    """
    if val is None:
        return d
    return val

Expand All @@ -59,6 +71,34 @@ def gradient_penalty(images, outputs, weight = 10):
gradients = gradients.reshape(batch_size, -1)
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    """Yield once per gradient-accumulation step, each time inside a context manager.

    Args:
        gradient_accumulate_every: number of accumulation steps to yield.
        is_ddp: when true, the first ``gradient_accumulate_every - 1`` steps run
            inside the combined ``no_sync`` contexts of ``ddps`` and only the
            final step runs inside the no-op ``null_context``.
        ddps: iterable of objects exposing a ``no_sync`` context-manager factory
            (only read when ``is_ddp`` is true).

    Yields:
        None, once per step; the caller's loop body executes inside the context.
    """
    if is_ddp:
        # Materialise the factories as a list: the same combined context object
        # is reused for every non-final step, and a lazy `map` iterator would be
        # exhausted after its first traversal.
        # NOTE(review): `combine_contexts` is defined elsewhere in the file; a
        # list is safe whether it iterates eagerly or on each entry.
        no_sync_factories = [ddp.no_sync for ddp in ddps]
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(no_sync_factories)] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield

def evaluate_in_chunks(max_batch_size, model, *args):
    """Run `model` over `args` in slices of at most `max_batch_size` along dim 0.

    Every positional tensor in `args` is split the same way; the i-th slices of
    all tensors are passed together to `model`. A single chunk's output is
    returned as-is; multiple outputs are concatenated along dim 0.
    """
    per_arg_chunks = [tensor.split(max_batch_size, dim=0) for tensor in args]

    outputs = []
    for chunk_group in zip(*per_arg_chunks):
        outputs.append(model(*chunk_group))

    if len(outputs) == 1:
        return outputs[0]
    return torch.cat(outputs, dim=0)

def slerp(val, low, high):
    """Spherical linear interpolation between the rows of `low` and `high`.

    Args:
        val: interpolation ratio (Python number or 0-dim tensor); 0 gives `low`,
            1 gives `high`.
        low: 2-D tensor; each row is a start vector (norms are taken over dim 1).
        high: 2-D tensor with the same shape as `low`; each row is an end vector.

    Returns:
        Tensor with the same shape as `low`.

    Rows whose directions are (anti)parallel have sin(omega) == 0, which would
    make the slerp weights 0/0 = NaN; those rows fall back to plain linear
    interpolation weights instead.
    """
    eps = 1e-6

    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)

    # Rounding can push the cosine marginally outside [-1, 1], where acos is NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    degenerate = so.abs() < eps
    # Divide by 1 on degenerate rows so no NaN/inf is produced before masking.
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    ones = torch.ones_like(so)
    weight_low = torch.where(degenerate, (1.0 - val) * ones, torch.sin((1.0 - val) * omega) / safe_so)
    weight_high = torch.where(degenerate, val * ones, torch.sin(val * omega) / safe_so)

    res = weight_low.unsqueeze(1) * low + weight_high.unsqueeze(1) * high
    return res

# helper classes

class NanException(Exception):
Expand All @@ -73,9 +113,19 @@ def update_average(self, old, new):
return new
return old * self.beta + (1 - self.beta) * new

class RandomApply(nn.Module):
    """Module that applies `fn` with probability `prob`, otherwise `fn_else`.

    `fn_else` defaults to the identity, so by default the input is passed
    through unchanged whenever `fn` is not selected.
    """

    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        # One uniform draw in [0, 1) per call decides which branch runs.
        if random() < self.prob:
            return self.fn(x)
        return self.fn_else(x)

# dataset

def convert_image_to(type_name, image):
def convert_image_to(img_type, image):
    """Return `image` in mode `img_type`.

    The same object is returned when its `mode` already matches; otherwise the
    result of `image.convert(img_type)` is returned.
    """
    already_in_mode = image.mode == img_type
    return image if already_in_mode else image.convert(img_type)
Expand Down Expand Up @@ -118,7 +168,7 @@ def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
assert len(self.paths) > 0, f'No images were found in {folder} for training'

convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'RBG')
convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'RGB')
num_channels = 3 if not transparent else 4

self.transform = transforms.Compose([
Expand All @@ -138,6 +188,28 @@ def __getitem__(self, index):
img = Image.open(path)
return self.transform(img)

# augmentations

def random_hflip(tensor, prob):
    """Randomly flip `tensor` along dim 3.

    Note the direction of `prob`: it is the probability that the input is
    returned unchanged (same object); the flip happens with probability
    ``1 - prob``.
    """
    keep_as_is = prob > random()
    if not keep_as_is:
        tensor = torch.flip(tensor, dims=(3,))
    return tensor

class AugWrapper(nn.Module):
    """Wraps a discriminator `D` and optionally augments images before calling it."""

    def __init__(self, D, image_size):
        # `image_size` is accepted for interface compatibility but is not used here.
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = None, detach = False):
        """Optionally augment `images`, optionally detach them, then return `self.D(images)`.

        Args:
            images: batch passed to the wrapped discriminator.
            prob: probability of applying the augmentation branch on this call.
            types: augmentation type names forwarded to `DiffAugment`; None
                means an empty list.
            detach: when true, the (possibly augmented) images are detached
                from the autograd graph before being passed to `D`.
        """
        # A fresh list per call avoids sharing one mutable default between calls.
        types = [] if types is None else types

        if random() < prob:
            images = random_hflip(images, prob=0.5)
            images = DiffAugment(images, types=types)

        if detach:
            images = images.detach()

        return self.D(images)

# classes

class SLE(nn.Module):
Expand Down Expand Up @@ -288,6 +360,11 @@ def __init__(

features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(8, 2, -1)))
features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))

if num_non_residual_layers == 0:
res, _ = features[0]
features[0] = (res, init_channel)

chan_in_out = zip(features[:-1], features[1:])

self.non_residual_layers = nn.ModuleList([])
Expand Down Expand Up @@ -383,7 +460,8 @@ def __init__(
fmap_inverse_coef = 12,
transparent = False,
ttur_mult = 1.5,
lr = 2e-4
lr = 2e-4,
rank = 0
):
super().__init__()
self.latent_dim = latent_dim
Expand All @@ -397,11 +475,20 @@ def __init__(
set_requires_grad(self.GE, False)

self.D = Discriminator(image_size = image_size, fmap_max = fmap_max, fmap_inverse_coef = fmap_inverse_coef, transparent = transparent)
self.D_aug = AugWrapper(self.D, image_size)

self.G_opt = Adam(self.G.parameters(), lr = lr, betas=(0.5, 0.9))
self.D_opt = Adam(self.D.parameters(), lr = lr * ttur_mult, betas=(0.5, 0.9))

self.cuda()
self._init_weights()
self.reset_parameter_averaging()

self.cuda(rank)

def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

def EMA(self):
def update_moving_average(ma_model, current_model):
Expand All @@ -426,8 +513,8 @@ def __init__(
results_dir = 'results',
models_dir = 'models',
base_dir = './',
latent_dim = 256,
image_size = 128,
network_capacity = 16,
fmap_max = 512,
transparent = False,
batch_size = 4,
Expand Down Expand Up @@ -462,6 +549,7 @@ def __init__(
self.config_path = self.models_dir / name / '.config.json'

assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
self.latent_dim = latent_dim
self.image_size = image_size
self.fmap_max = fmap_max
self.transparent = transparent
Expand All @@ -483,6 +571,7 @@ def __init__(
self.d_loss = 0
self.g_loss = 0
self.last_gp_loss = None
self.last_recon_loss = None
self.last_fid = None

self.init_folders()
Expand Down Expand Up @@ -510,6 +599,7 @@ def init_GAN(self):

self.GAN = LightweightGAN(
lr = self.lr,
latent_dim = self.latent_dim,
image_size = self.image_size,
ttur_mult = self.ttur_mult,
fmap_max = self.fmap_max,
Expand All @@ -532,20 +622,18 @@ def write_config(self):
def load_config(self):
    """Restore model settings from the saved config file and rebuild the GAN.

    Reads the JSON at `self.config_path` when it exists, otherwise falls back
    to the in-memory `self.config()`. Sets `image_size`, `transparent` and
    `fmap_max` (default 512 when absent), then discards the current
    `self.GAN` and calls `self.init_GAN()`.
    """
    if self.config_path.exists():
        config = json.loads(self.config_path.read_text())
    else:
        config = self.config()

    self.image_size = config['image_size']
    self.transparent = config['transparent']
    # 'network_capacity' is intentionally not read here: the dict returned by
    # `config()` no longer carries that key, so indexing it raises KeyError.
    self.fmap_max = config.get('fmap_max', 512)

    del self.GAN
    self.init_GAN()

def config(self):
    """Return the settings dict for this trainer: `image_size` and `transparent`.

    These are the keys `load_config` indexes unconditionally. The earlier
    return value also listed attributes such as `network_capacity` and
    `fq_layers`; it is dropped because it shadowed this one.
    """
    return {'image_size': self.image_size, 'transparent': self.transparent}

def set_data_src(self, folder):
    """Build the image dataset for `folder` and an endlessly cycling loader over it.

    Sets `self.dataset` and `self.loader`. When `self.is_ddp` is true a
    `DistributedSampler` handles shuffling; otherwise the `DataLoader`
    shuffles itself.
    """
    # Superseded lines removed: a `Dataset(...)` construction overwritten by the
    # next assignment, an unused `num_workers` double assignment, and a
    # `data.DataLoader(...)` call whose result was immediately reassigned.
    self.dataset = ImageDataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)

    # NOTE(review): `DistributedSampler` is not among the imports visible in
    # this view -- confirm it is imported elsewhere in the file.
    sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None

    # The global batch is divided across processes, rounding up.
    per_process_batch_size = math.ceil(self.batch_size / self.world_size)
    dataloader = DataLoader(
        self.dataset,
        num_workers = NUM_CORES,
        batch_size = per_process_batch_size,
        sampler = sampler,
        shuffle = not self.is_ddp,
        drop_last = True,
        pin_memory = True
    )
    self.loader = cycle(dataloader)

def train(self):
Expand Down Expand Up @@ -575,33 +663,36 @@ def train(self):

# train discriminator

avg_pl_length = self.pl_mean
self.GAN.D_opt.zero_grad()

for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]):
latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

generated_images = G(w_styles)
fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)
generated_images = G(latents)
fake_output, fake_aux_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)

image_batch = next(self.loader).cuda(self.rank)
image_batch.requires_grad_()
real_output, real_q_loss = D_aug(image_batch, **aug_kwargs)
real_output, real_aux_loss = D_aug(image_batch, **aug_kwargs)

real_output_loss = real_output
fake_output_loss = fake_output

divergence = (F.relu(1 + real_output_loss) + F.relu(1 - fake_output_loss)).mean()
disc_loss = divergence

aux_loss = fake_aux_loss + real_aux_loss
self.last_recon_loss = aux_loss.clone().detach().item()
disc_loss = disc_loss + aux_loss

if apply_gradient_penalty:
gp = gradient_penalty(image_batch, (real_output,))
self.last_gp_loss = gp.clone().detach().item()
disc_loss = disc_loss + gp

disc_loss = disc_loss / self.gradient_accumulate_every
disc_loss.register_hook(raise_if_nan)
disc_loss.backwards()
disc_loss.backward()

total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every

Expand All @@ -624,7 +715,7 @@ def train(self):

gen_loss = gen_loss / self.gradient_accumulate_every
gen_loss.register_hook(raise_if_nan)
gen_loss.backwards()
gen_loss.backward()

total_gen_loss += loss.detach().item() / self.gradient_accumulate_every

Expand Down Expand Up @@ -682,12 +773,12 @@ def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0):

# regular

generated_images = self.generate_truncated(self.GAN.G, latents, trunc_psi = self.trunc_psi)
generated_images = self.generate_truncated(self.GAN.G, latents)
torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)

# moving averages

generated_images = self.generate_truncated(self.GAN.GE, latents, trunc_psi = self.trunc_psi)
generated_images = self.generate_truncated(self.GAN.GE, latents)
torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)

@torch.no_grad()
Expand Down Expand Up @@ -720,16 +811,16 @@ def calculate_fid(self, num_batches):
latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

# moving averages
generated_images = self.generate_truncated(self.GAN.GE, latents, trunc_psi = self.trunc_psi)
generated_images = self.generate_truncated(self.GAN.GE, latents)

for j in range(generated_images.size(0)):
torchvision.utils.save_image(generated_images[j, :, :, :], str(Path(fake_path) / f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))

return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, True, 2048)

@torch.no_grad()
def generate_truncated(self, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
generated_images = evaluate_in_chunks(self.batch_size, G, style, noi)
def generate_truncated(self, G, style, trunc_psi = 0.75, num_image_tiles = 8):
    """Run generator `G` on `style` in chunks of `self.batch_size` and return the images.

    `trunc_psi` and `num_image_tiles` are accepted but not used by this body.
    """
    return evaluate_in_chunks(self.batch_size, G, style)

@torch.no_grad()
Expand All @@ -751,8 +842,7 @@ def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_
frames = []
for ratio in tqdm(ratios):
interp_latents = slerp(ratio, latents_low, latents_high)
latents = [(interp_latents, num_layers)]
generated_images = self.generate_truncated(self.GAN.GE, latents, trunc_psi = self.trunc_psi)
generated_images = self.generate_truncated(self.GAN.GE, interp_latents)
images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
pil_image = transforms.ToPILImage()(images_grid.cpu())

Expand All @@ -775,6 +865,7 @@ def print_log(self):
('G', self.g_loss),
('D', self.d_loss),
('GP', self.last_gp_loss),
('SS', self.last_recon_loss),
('FID', self.last_fid)
]

Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.8'
__version__ = '0.1.0'

0 comments on commit 3cd5d84

Please sign in to comment.