diff --git a/README.md b/README.md index ed5641f..0599298 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,19 @@ -## 'Lightweight' GAN (wip) +## 'Lightweight' GAN Implementation of an extremely 'lightweight' GAN proposed in ICLR 2021, in Pytorch. The main contributions of the paper is a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images". +## Install + +```bash +$ pip install lightweight-gan +``` + +## Usage + +```bash +$ lightweight_gan --data ./path/to/images --image-size 512 +``` + ## Citations ```bibtex diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 0448c0d..8991612 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -1,20 +1,28 @@ +import json import multiprocessing from random import random +import math from math import log2, floor from functools import partial +from contextlib import contextmanager +from pathlib import Path +from shutil import rmtree import torch from torch.optim import Adam from torch import nn, einsum import torch.nn.functional as F -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader +from torch.autograd import grad as torch_grad +from PIL import Image import torchvision from torchvision import transforms from lightweight_gan.diff_augment import DiffAugment from lightweight_gan.version import __version__ +from tqdm import tqdm from einops import rearrange from pytorch_fid import fid_score @@ -34,6 +42,10 @@ def exists(val): return val is not None +@contextmanager +def null_context(): + yield + def default(val, d): return val if exists(val) else d @@ -59,6 +71,34 @@ def gradient_penalty(images, outputs, weight = 10): gradients = gradients.reshape(batch_size, -1) return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() +def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps): + if is_ddp: + num_no_syncs = gradient_accumulate_every - 1 + head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs + tail = [null_context] + contexts = head + tail + else: + contexts = [null_context] * gradient_accumulate_every + + for context in contexts: + with context(): + yield + +def evaluate_in_chunks(max_batch_size, model, *args): + split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) + chunked_outputs = [model(*i) for i in split_args] + if len(chunked_outputs) == 1: + return chunked_outputs[0] + return torch.cat(chunked_outputs, dim=0) + +def slerp(val, low, high): + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high + return res + # helper classes class NanException(Exception): @@ -73,9 +113,19 @@ def update_average(self, old, new): return new return old * self.beta + (1 - self.beta) * new +class RandomApply(nn.Module): + def __init__(self, prob, fn, fn_else = lambda x: x): + super().__init__() + self.fn = fn + self.fn_else = fn_else + self.prob = prob + def forward(self, x): + fn = self.fn if random() < self.prob else self.fn_else + return fn(x) + # dataset -def convert_image_to(type_name, image): +def convert_image_to(img_type, image): if image.mode != img_type: return image.convert(img_type) return image @@ -118,7 +168,7 @@ def __init__(self, folder, image_size, transparent = False, aug_prob = 0.): self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')] assert len(self.paths) > 0, f'No images were found in {folder} for training' - convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'RBG') + convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'RGB') num_channels = 3 if not transparent else 4 self.transform = transforms.Compose([ @@ -138,6 +188,28 @@ def __getitem__(self, index): img = Image.open(path) return self.transform(img) +# augmentations + +def random_hflip(tensor, prob): + if prob > random(): + return tensor + return torch.flip(tensor, dims=(3,)) + +class AugWrapper(nn.Module): + def __init__(self, D, image_size): + super().__init__() + self.D = D + + def forward(self, images, prob = 0., types = [], detach = False): + if random() < prob: + images = random_hflip(images, prob=0.5) + images = DiffAugment(images, types=types) + + if detach: + images = images.detach() + + return self.D(images) + # classes class SLE(nn.Module): @@ -288,6 +360,11 @@ def __init__( features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(8, 2, -1))) features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features)) + + if num_non_residual_layers == 0: + res, _ = features[0] + features[0] = (res, init_channel) + chan_in_out = zip(features[:-1], features[1:]) self.non_residual_layers = nn.ModuleList([]) @@ -383,7 +460,8 @@ def __init__( fmap_inverse_coef = 12, transparent = False, ttur_mult = 1.5, - lr = 2e-4 + lr = 2e-4, + rank = 0 ): super().__init__() self.latent_dim = latent_dim @@ -397,11 +475,20 @@ def __init__( set_requires_grad(self.GE, False) self.D = Discriminator(image_size = image_size, fmap_max = fmap_max, fmap_inverse_coef = fmap_inverse_coef, transparent = transparent) + self.D_aug = AugWrapper(self.D, image_size) self.G_opt = Adam(self.G.parameters(), lr = lr, betas=(0.5, 0.9)) self.D_opt = Adam(self.D.parameters(), lr = lr * ttur_mult, betas=(0.5, 0.9)) - self.cuda() + self._init_weights() + self.reset_parameter_averaging() + + self.cuda(rank) + + def _init_weights(self): + for m in self.modules(): + if type(m) in {nn.Conv2d, nn.Linear}: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') def EMA(self): def update_moving_average(ma_model, current_model): @@ -426,8 +513,8 @@ def __init__( results_dir = 'results', models_dir = 'models', base_dir = './', + latent_dim = 256, image_size = 128, - network_capacity = 16, fmap_max = 512, transparent = False, batch_size = 4, @@ -462,6 +549,7 @@ def __init__( self.config_path = self.models_dir / name / '.config.json' assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' + self.latent_dim = latent_dim self.image_size = image_size self.fmap_max = fmap_max self.transparent = transparent @@ -483,6 +571,7 @@ def __init__( self.d_loss = 0 self.g_loss = 0 self.last_gp_loss = None + self.last_recon_loss = None self.last_fid = None self.init_folders() @@ -510,6 +599,7 @@ def init_GAN(self): self.GAN = LightweightGAN( lr = self.lr, + latent_dim = self.latent_dim, image_size = self.image_size, ttur_mult = self.ttur_mult, fmap_max = self.fmap_max, @@ -532,20 +622,18 @@ def write_config(self): def load_config(self): config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text()) self.image_size = config['image_size'] - self.network_capacity = config['network_capacity'] self.transparent = config['transparent'] self.fmap_max = config.pop('fmap_max', 512) del self.GAN self.init_GAN() def config(self): - return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const} + return {'image_size': self.image_size, 'transparent': self.transparent} def set_data_src(self, folder): - self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob) - num_workers = num_workers = default(self.num_workers, num_cores) + self.dataset = ImageDataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob) sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None - dataloader = data.DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True) + dataloader = DataLoader(self.dataset, num_workers = NUM_CORES, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True) self.loader = cycle(dataloader) def train(self): @@ -575,18 +663,17 @@ def train(self): # train discriminator - avg_pl_length = self.pl_mean self.GAN.D_opt.zero_grad() for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) - generated_images = G(w_styles) - fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs) + generated_images = G(latents) + fake_output, fake_aux_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs) image_batch = next(self.loader).cuda(self.rank) image_batch.requires_grad_() - real_output, real_q_loss = D_aug(image_batch, **aug_kwargs) + real_output, real_aux_loss = D_aug(image_batch, **aug_kwargs) real_output_loss = real_output fake_output_loss = fake_output @@ -594,6 +681,10 @@ def train(self): divergence = (F.relu(1 + real_output_loss) + F.relu(1 - fake_output_loss)).mean() disc_loss = divergence + aux_loss = fake_aux_loss + real_aux_loss + self.last_recon_loss = aux_loss.clone().detach().item() + disc_loss = disc_loss + aux_loss + if apply_gradient_penalty: gp = gradient_penalty(image_batch, (real_output,)) self.last_gp_loss = gp.clone().detach().item() @@ -601,7 +692,7 @@ def train(self): disc_loss = disc_loss / self.gradient_accumulate_every disc_loss.register_hook(raise_if_nan) - disc_loss.backwards() + disc_loss.backward() total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every @@ -624,7 +715,7 @@ def train(self): gen_loss = gen_loss / self.gradient_accumulate_every gen_loss.register_hook(raise_if_nan) - gen_loss.backwards() + gen_loss.backward() total_gen_loss += loss.detach().item() / self.gradient_accumulate_every @@ -682,12 +773,12 @@ def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0): # regular - generated_images = self.generate_truncated(self.GAN.G, latents, trunc_psi = self.trunc_psi) + generated_images = self.generate_truncated(self.GAN.G, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) # moving averages - generated_images = self.generate_truncated(self.GAN.GE, latents, trunc_psi = self.trunc_psi) + generated_images = self.generate_truncated(self.GAN.GE, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows) @torch.no_grad() @@ -720,7 +811,7 @@ def calculate_fid(self, num_batches): latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank) # moving averages - generated_images = self.generate_truncated(self.GAN.GE, latents, trunc_psi = self.trunc_psi) + generated_images = self.generate_truncated(self.GAN.GE, latents) for j in range(generated_images.size(0)): torchvision.utils.save_image(generated_images[j, :, :, :], str(Path(fake_path) / f'{str(j + batch_num * self.batch_size)}-ema.{ext}')) @@ -728,8 +819,8 @@ def calculate_fid(self, num_batches): return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, True, 2048) @torch.no_grad() - def generate_truncated(self, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8): - generated_images = evaluate_in_chunks(self.batch_size, G, style, noi) + def generate_truncated(self, G, style, trunc_psi = 0.75, num_image_tiles = 8): + generated_images = evaluate_in_chunks(self.batch_size, G, style) return generated_images @torch.no_grad() @@ -751,8 +842,7 @@ def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_ frames = [] for ratio in tqdm(ratios): interp_latents = slerp(ratio, latents_low, latents_high) - latents = [(interp_latents, num_layers)] - generated_images = self.generate_truncated(self.GAN.GE, latents, trunc_psi = self.trunc_psi) + generated_images = self.generate_truncated(self.GAN.GE, interp_latents) images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows) pil_image = transforms.ToPILImage()(images_grid.cpu()) @@ -775,6 +865,7 @@ def print_log(self): ('G', self.g_loss), ('D', self.d_loss), ('GP', self.last_gp_loss), + ('SS', self.last_recon_loss), ('FID', self.last_fid) ] diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index 9123cf0..b794fd4 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.0.8' +__version__ = '0.1.0'