From c289d484f10bbaf76c0c8485220bd58d5cfceded Mon Sep 17 00:00:00 2001
From: Hosein Hashemi
Date: Sun, 16 Aug 2020 19:14:03 +0200
Subject: [PATCH 1/2] Create LO-WGAN-gp.py

---
 implementations/LO-WGAN-gp/LO-WGAN-gp.py | 258 +++++++++++++++++++++++
 1 file changed, 258 insertions(+)
 create mode 100644 implementations/LO-WGAN-gp/LO-WGAN-gp.py

diff --git a/implementations/LO-WGAN-gp/LO-WGAN-gp.py b/implementations/LO-WGAN-gp/LO-WGAN-gp.py
new file mode 100644
index 00000000..a9ba956e
--- /dev/null
+++ b/implementations/LO-WGAN-gp/LO-WGAN-gp.py
@@ -0,0 +1,258 @@
+import argparse
+import os
+import numpy as np
+import math
+import sys
+
+import torchvision.transforms as transforms
+from torchvision.utils import save_image
+
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torch.autograd import Variable
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.autograd as autograd
+import torch
+
+os.makedirs("images", exist_ok=True)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
+parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
+parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
+parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
+parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
+parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
+parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
+parser.add_argument("--latent_method", type=str, default="ngd", help="latent optimization method: 'gd' or 'ngd'")
+parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
+parser.add_argument("--channels", type=int, default=1, help="number of image channels")
+parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
+parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
+opt = parser.parse_args()
+print(opt)
+
+img_shape = (opt.channels, opt.img_size, opt.img_size)
+
+cuda = True if torch.cuda.is_available() else False
+
+
+class Generator(nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+
+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
+        self.model = nn.Sequential(
+            *block(opt.latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
+            nn.Tanh()
+        )
+
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.shape[0], *img_shape)
+        return img
+
+
+class Discriminator(nn.Module):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+
+        self.model = nn.Sequential(
+            nn.Linear(int(np.prod(img_shape)), 512),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(512, 256),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(256, 1),
+        )
+
+    def forward(self, img):
+        img_flat = img.view(img.shape[0], -1)
+        validity = self.model(img_flat)
+        return validity
+
+
+# Loss weight for gradient penalty
+lambda_gp = 10
+
+# Initialize generator and discriminator
+generator = Generator()
+discriminator = Discriminator()
+
+if cuda:
+    generator.cuda()
+    discriminator.cuda()
+
+# Configure data loader
+os.makedirs("../../data/mnist", exist_ok=True)
+dataloader = torch.utils.data.DataLoader(
+    datasets.MNIST(
+        "../../data/mnist",
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
+        ),
+    ),
+    batch_size=opt.batch_size,
+    shuffle=True,
+)
+
+# Optimizers
+optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+
+Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
+
+
+# Latent optimization
+def latent_opt(Gen, Dis, z, method, batch_size, alpha=0.9, beta=0.1):
+    """Take one gradient-based optimization step on the latent code z using the critic score."""
+    method = method.lower()
+
+    # Using gradient descent
+    if method == "gd":
+        fake = Gen(z)
+        f_z = Dis(fake.view(batch_size, opt.channels, opt.img_size, opt.img_size))
+
+        # Gradient of the critic score with respect to the latent code
+        d_fz = torch.autograd.grad(
+            outputs=f_z,
+            inputs=z,
+            grad_outputs=torch.ones_like(f_z),
+            retain_graph=True,
+            create_graph=True,
+        )[0]
+
+        delta_z = alpha * d_fz
+
+        with torch.no_grad():
+            z_prime = torch.clamp(z + delta_z, min=-1, max=1)
+
+        return z_prime
+
+    # Using natural gradient descent
+    elif method == "ngd":
+        fake = Gen(z)
+        f_z = Dis(fake.view(batch_size, opt.channels, opt.img_size, opt.img_size))
+
+        # Gradient of the critic score with respect to the latent code
+        d_fz = torch.autograd.grad(
+            outputs=f_z,
+            inputs=z,
+            grad_outputs=torch.ones_like(f_z),
+            retain_graph=True,
+            create_graph=True,
+        )[0]
+
+        # Scale the step by the per-sample L2 norm of the gradient
+        delta_z = (alpha * d_fz) / (beta + torch.norm(d_fz, p=2, dim=1, keepdim=True))
+
+        with torch.no_grad():
+            z_prime = torch.clamp(z + delta_z, min=-1, max=1)
+
+        return z_prime
+
+    else:
+        raise ValueError("Unknown latent optimization method: %s" % method)
+
+
+def compute_gradient_penalty(D, real_samples, fake_samples):
+    """Calculates the gradient penalty loss for WGAN GP"""
+    # Random weight term for interpolation between real and fake samples
+    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
+    # Get random interpolation between real and fake samples
+    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
+    d_interpolates = D(interpolates)
+    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
+    # Get gradient w.r.t. interpolates
+    gradients = autograd.grad(
+        outputs=d_interpolates,
+        inputs=interpolates,
+        grad_outputs=fake,
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True,
+    )[0]
+    gradients = gradients.view(gradients.size(0), -1)
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+
+# ----------
+#  Training
+# ----------
+
+batches_done = 0
+for epoch in range(opt.n_epochs):
+
+    for i, (imgs, _) in enumerate(dataloader):
+
+        # Configure input
+        real_imgs = Variable(imgs.type(Tensor))
+
+        # --------------------------
+        #  Latent optimization step
+        # --------------------------
+
+        optimizer_G.zero_grad()
+        optimizer_D.zero_grad()
+
+        # Sample noise and optimize it before it is fed to the generator
+        z = Variable(Tensor(np.random.uniform(-1, 1, (imgs.shape[0], opt.latent_dim))), requires_grad=True)
+        z_prime = latent_opt(generator, discriminator, z, opt.latent_method, imgs.shape[0])
+
+        # ---------------------
+        #  Train Discriminator
+        # ---------------------
+
+        optimizer_D.zero_grad()
+
+        # Generate a batch of images
+        fake_imgs = generator(z_prime)
+
+        # Real images
+        real_validity = discriminator(real_imgs)
+        # Fake images
+        fake_validity = discriminator(fake_imgs)
+        # Gradient penalty
+        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
+        # Adversarial loss
+        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
+
+        d_loss.backward()
+        optimizer_D.step()
+
+        optimizer_G.zero_grad()
+
+        # Train the generator every n_critic steps
+        if i % opt.n_critic == 0:
+
+            # -----------------
+            #  Train Generator
+            # -----------------
+
+            # Generate a batch of images
+            fake_imgs = generator(z_prime)
+            # Loss measures generator's ability to fool the discriminator
+            # Train on fake images
+            fake_validity = discriminator(fake_imgs)
+            g_loss = -torch.mean(fake_validity)
+
+            g_loss.backward()
+            optimizer_G.step()
+
+            print(
+                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
+                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
+            )
+
+            if batches_done % opt.sample_interval == 0:
+                save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
+
+            batches_done += opt.n_critic

From 8e3910382a369e894dbae4a84e11fb29a7b81a74 Mon Sep 17 00:00:00 2001
From: Hosein Hashemi
Date: Sun, 16 Aug 2020 19:27:02 +0200
Subject: [PATCH 2/2] Rename LO-WGAN-gp.py to lo-wgan-gp.py

---
 implementations/LO-WGAN-gp/{LO-WGAN-gp.py => lo-wgan-gp.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename implementations/LO-WGAN-gp/{LO-WGAN-gp.py => lo-wgan-gp.py} (100%)

diff --git a/implementations/LO-WGAN-gp/LO-WGAN-gp.py b/implementations/LO-WGAN-gp/lo-wgan-gp.py
similarity index 100%
rename from implementations/LO-WGAN-gp/LO-WGAN-gp.py
rename to implementations/LO-WGAN-gp/lo-wgan-gp.py