Showing 13 changed files with 639 additions and 1 deletion.
.gitignore
@@ -0,0 +1,2 @@
__pycache__
.vscode
README.md
@@ -1 +1,36 @@
-# BigGAN-Generator-Pretrained
+# BigGAN Generators with Pretrained Weights in PyTorch
A PyTorch implementation of the generator from *Large Scale GAN Training for High Fidelity Natural Image Synthesis* (BigGAN).

# Download Pretrained Weights

# Demo
```shell
python demo.py -w <PRETRAINED_WEIGHT_PATH> [-s IMAGE_SIZE] [-c CLASS_LABEL] [-t TRUNCATION]
python demo.py -w ./biggan512-release.pt -s 512 -t 0.3 -c 156
python demo.py -w ./biggan256-release.pt -s 256 -t 0.02 -c 11
python demo.py --pretrained_weight ./biggan128-release.pt --size 128 --truncation 0.2 --class_label 821
```
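For programmatic use, the equivalent of the demo in a few lines of Python; this is a minimal sketch distilled from demo.py, and the checkpoint path is a placeholder:

```python
import torch
from scipy.stats import truncnorm

from src.biggan import BigGAN512

truncation = 0.3
# 512-px model uses a 128-dim latent, sampled from a truncated normal.
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 128))).float()
c = torch.tensor([156]).long()  # ImageNet class label

biggan = BigGAN512()
biggan.load_state_dict(torch.load('./biggan512-release.pt'))  # placeholder path
biggan.eval()
with torch.no_grad():
    img = biggan(z, c, truncation)  # (1, 3, 512, 512), values in [-1, 1]
```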

# Results
|![alt text](./assets/p1.png)|
|:--:|
|*class 156 (512 x 512)*|
|![alt text](./assets/p2.png)|
|*class 11 (512 x 512)*|
|![alt text](./assets/p3.png)|
|*class 821 (512 x 512)*|

# Dependencies
Please refer to the environment.yml file; the environment can typically be recreated with `conda env create -f environment.yml`.

# Pretrained Weights
The pretrained weights are converted from the TensorFlow Hub modules:
- https://tfhub.dev/deepmind/biggan-128/2
- https://tfhub.dev/deepmind/biggan-256/2
- https://tfhub.dev/deepmind/biggan-512/2

# References
paper: https://arxiv.org/abs/1809.11096
assets/p1.png, assets/p2.png, assets/p3.png
Binary image files added (the sample outputs shown under Results).
demo.py
@@ -0,0 +1,41 @@
from src.biggan import BigGAN128
from src.biggan import BigGAN256
from src.biggan import BigGAN512

import torch
import torchvision

from scipy.stats import truncnorm

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--truncation', type=float, default=0.4)
    parser.add_argument('-s', '--size', type=int, choices=[128, 256, 512], default=512)
    parser.add_argument('-c', '--class_label', type=int, choices=range(0, 1000), default=156)
    parser.add_argument('-w', '--pretrained_weight', type=str, required=True)
    args = parser.parse_args()

    # Clamp truncation into the range covered by the standing statistics (0.02..1.0).
    truncation = torch.clamp(torch.tensor(args.truncation), min=0.02+1e-4, max=1.0-1e-4).float()
    c = torch.tensor((args.class_label,)).long()

    # Sample a truncated-normal latent; its dimensionality depends on the model size.
    if args.size == 128:
        z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 120))).float()
        biggan = BigGAN128()
    elif args.size == 256:
        z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 140))).float()
        biggan = BigGAN256()
    elif args.size == 512:
        z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 128))).float()
        biggan = BigGAN512()

    biggan.load_state_dict(torch.load(args.pretrained_weight))
    biggan.eval()
    with torch.no_grad():
        img = biggan(z, c, truncation.item())

    # Map the tanh output from [-1, 1] to [0, 1] and display it.
    img = 0.5 * (img.data + 1)
    pil = torchvision.transforms.ToPILImage()(img.squeeze())
    pil.show()
environment.yml
@@ -0,0 +1,10 @@
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- python=3.6
- cudatoolkit=10.0
- pytorch
- torchvision
- scipy
src/biggan/attention.py
@@ -0,0 +1,32 @@
from .spectral_normalization import spectral_norm

import torch
from torch import nn
from torch.nn import functional as F


class SelfAttention2d(nn.Module):
    """Self-attention block with spectrally normalized 1x1 convolutions."""
    def __init__(self, in_channels, c_bar, c_hat, eps=1e-4):
        super().__init__()
        self.theta = spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=c_bar, kernel_size=1, bias=False), eps=eps)
        self.phi = spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=c_bar, kernel_size=1, bias=False), eps=eps)
        self.g = spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=c_hat, kernel_size=1, bias=False), eps=eps)
        self.o_conv = spectral_norm(nn.Conv2d(in_channels=c_hat, out_channels=in_channels, kernel_size=1, bias=False), eps=eps)
        # Learnable gate, initialized to zero so the block starts as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        # Keys and values are max-pooled 2x to shrink the attention map.
        g_x = self.g(x)
        g_x = F.max_pool2d(g_x, kernel_size=2)
        g_x = g_x.view(n, -1, h*w//4)
        phi_x = self.phi(x)
        phi_x = F.max_pool2d(phi_x, kernel_size=2)
        phi_x = phi_x.view(n, -1, h*w//4)
        theta_x = self.theta(x)
        theta_x = theta_x.view(n, -1, h*w)
        # Attention over all (pooled) spatial positions for each query position.
        attn = F.softmax(torch.bmm(theta_x.permute(0, 2, 1), phi_x), dim=-1)
        y = torch.bmm(g_x, attn.permute(0, 2, 1))
        y = y.view(n, -1, h, w)
        o = self.o_conv(y)
        z = self.gamma * o + x
        return z
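A quick shape check of the block; this is a hypothetical standalone usage, assuming the file lives at src/biggan/attention.py, with the channel arguments mirroring BigGAN128's attention layer:

```python
import torch
from src.biggan.attention import SelfAttention2d

attn = SelfAttention2d(in_channels=192, c_bar=24, c_hat=96)
# Spatial dims must be even because keys/values are pooled 2x.
x = torch.randn(1, 192, 64, 64)
print(attn(x).shape)  # torch.Size([1, 192, 64, 64]) -- same shape as the input
```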
src/biggan/batch_normalization.py
@@ -0,0 +1,61 @@
from .spectral_normalization import spectral_norm

import torch
from torch import nn
from torch.nn import functional as F

class CrossReplicaBN2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        # Standing statistics precomputed for 50 truncation values (0.02, 0.04, ..., 1.00);
        # they are populated by the converted pretrained weights.
        self.register_buffer('standing_means', torch.empty(50, num_features))
        self.register_buffer('standing_vars', torch.empty(50, num_features))

    @torch._jit_internal.weak_script_method
    def forward(self, input, truncation=1.0):
        self._check_input_dim(input)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum
        if not self.training:
            # At inference, normalize with standing statistics selected by truncation.
            standing_mean = self.get_standing_stats(self.standing_means, truncation)
            standing_var = self.get_standing_stats(self.standing_vars, truncation)
            return F.batch_norm(
                input, standing_mean, standing_var, self.weight, self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor, self.eps
            )
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps
        )

    def get_standing_stats(self, stack, truncation):
        # Row idx corresponds to truncation (idx + 1) * step; off-grid truncations
        # are blended linearly from two neighboring rows.
        lo = 0.02 - 1e-12
        hi = 1.0 + 1e-12
        step = 0.02
        assert lo <= truncation <= hi
        idx = round((truncation - step) / step)
        residual = truncation - idx * step
        alpha = round(residual / step, 2)
        return torch.sum(torch.cat((alpha*stack[idx:idx+1], (1.0-alpha)*stack[idx+1:idx+2])), dim=0)

class ScaledCrossReplicaBN2d(CrossReplicaBN2d):
    pass

class HyperBN2d(nn.Module):
    """Conditional batch norm: per-channel gain and bias are predicted from the condition vector."""
    def __init__(self, num_features, latent_dim, eps=1e-4):
        super().__init__()
        self.crossreplicabn = CrossReplicaBN2d(num_features=num_features, affine=False, eps=eps)
        self.gamma = spectral_norm(nn.Linear(in_features=latent_dim, out_features=num_features, bias=False), eps=eps)
        self.beta = spectral_norm(nn.Linear(in_features=latent_dim, out_features=num_features, bias=False), eps=eps)

    def forward(self, x, condition, truncation=1.0):
        gamma = self.gamma(condition).view(condition.size(0), -1, 1, 1)
        beta = self.beta(condition).view(condition.size(0), -1, 1, 1)
        return (gamma + 1) * self.crossreplicabn(x, truncation) + beta
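The bucket arithmetic above maps a truncation value onto rows of the 50-row standing-statistics buffers, where row i corresponds to truncation (i + 1) * 0.02; on-grid values (like those used in the README demo) resolve to a single row with alpha == 1.0. A small standalone check of the index math:

```python
# Reproduces get_standing_stats' bucket selection for on-grid truncations.
step = 0.02
for truncation in (0.02, 0.30, 1.00):
    idx = round((truncation - step) / step)
    alpha = round((truncation - idx * step) / step, 2)
    print(truncation, idx, alpha)
# 0.02 -> row 0, 0.30 -> row 14, 1.00 -> row 49, each with alpha == 1.0
```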
src/biggan/biggan.py
@@ -0,0 +1,111 @@
from .spectral_normalization import spectral_norm
from .batch_normalization import ScaledCrossReplicaBN2d
from .attention import SelfAttention2d
from .module import G_z
from .module import GBlock

import torch
from torch import nn
from torch.nn import functional as F

class BigGAN128(nn.Module):
    """128x128 generator: 120-dim latent split into six 20-dim chunks."""
    def __init__(self):
        super().__init__()
        ch = 96
        self.linear = nn.Embedding(num_embeddings=1000, embedding_dim=128)
        # latent_dim of each block = 20 (latent chunk) + 128 (class embedding) = 148
        self.g_z = G_z(in_features=20, out_features=4*4*16*ch, eps=1e-4)
        self.gblock = GBlock(in_channels=16*ch, out_channels=16*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_1 = GBlock(in_channels=16*ch, out_channels=8*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_2 = GBlock(in_channels=8*ch, out_channels=4*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_3 = GBlock(in_channels=4*ch, out_channels=2*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.attention = SelfAttention2d(in_channels=2*ch, c_bar=ch//4, c_hat=ch)
        self.gblock_4 = GBlock(in_channels=2*ch, out_channels=1*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.scaledcrossreplicabn = ScaledCrossReplicaBN2d(num_features=ch, eps=1e-4)
        self.conv_2d = spectral_norm(nn.Conv2d(in_channels=ch, out_channels=3, kernel_size=3, stride=1, padding=1), eps=1e-4)

    def forward(self, z, c, truncation=1.0):
        # Hierarchical latent: one chunk feeds the input linear layer, the rest
        # condition the blocks alongside the class embedding.
        z_gz, z_0, z_1, z_2, z_3, z_4 = torch.split(z, split_size_or_sections=20, dim=1)
        cond = self.linear(c)
        x = self.g_z(z_gz)
        x = x.view(z.size(0), 4, 4, -1).permute(0, 3, 1, 2)
        x = self.gblock(x, torch.cat((z_0, cond), dim=1), truncation)
        x = self.gblock_1(x, torch.cat((z_1, cond), dim=1), truncation)
        x = self.gblock_2(x, torch.cat((z_2, cond), dim=1), truncation)
        x = self.gblock_3(x, torch.cat((z_3, cond), dim=1), truncation)
        x = self.attention(x)
        x = self.gblock_4(x, torch.cat((z_4, cond), dim=1), truncation)
        x = self.scaledcrossreplicabn(x, truncation)
        x = torch.relu(x)
        x = self.conv_2d(x)
        x = torch.tanh(x)
        return x

class BigGAN256(nn.Module):
    """256x256 generator: 140-dim latent split into seven 20-dim chunks."""
    def __init__(self):
        super().__init__()
        ch = 96
        self.linear = nn.Embedding(num_embeddings=1000, embedding_dim=128)
        self.g_z = G_z(in_features=20, out_features=4*4*16*ch, eps=1e-4)
        self.gblock = GBlock(in_channels=16*ch, out_channels=16*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_1 = GBlock(in_channels=16*ch, out_channels=8*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_2 = GBlock(in_channels=8*ch, out_channels=8*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_3 = GBlock(in_channels=8*ch, out_channels=4*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.gblock_4 = GBlock(in_channels=4*ch, out_channels=2*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.attention = SelfAttention2d(in_channels=2*ch, c_bar=ch//4, c_hat=ch)
        self.gblock_5 = GBlock(in_channels=2*ch, out_channels=1*ch, kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.scaledcrossreplicabn = ScaledCrossReplicaBN2d(num_features=ch, eps=1e-4)
        self.conv_2d = spectral_norm(nn.Conv2d(in_channels=ch, out_channels=3, kernel_size=3, stride=1, padding=1), eps=1e-4)

    def forward(self, z, c, truncation=1.0):
        z_gz, z_0, z_1, z_2, z_3, z_4, z_5 = torch.split(z, split_size_or_sections=20, dim=1)
        cond = self.linear(c)
        x = self.g_z(z_gz)
        x = x.view(z.size(0), 4, 4, -1).permute(0, 3, 1, 2)
        x = self.gblock(x, torch.cat((z_0, cond), dim=1), truncation)
        x = self.gblock_1(x, torch.cat((z_1, cond), dim=1), truncation)
        x = self.gblock_2(x, torch.cat((z_2, cond), dim=1), truncation)
        x = self.gblock_3(x, torch.cat((z_3, cond), dim=1), truncation)
        x = self.gblock_4(x, torch.cat((z_4, cond), dim=1), truncation)
        x = self.attention(x)
        x = self.gblock_5(x, torch.cat((z_5, cond), dim=1), truncation)
        x = self.scaledcrossreplicabn(x, truncation)
        x = torch.relu(x)
        x = self.conv_2d(x)
        x = torch.tanh(x)
        return x

class BigGAN512(nn.Module):
    """512x512 generator: 128-dim latent split into eight 16-dim chunks."""
    def __init__(self):
        super().__init__()
        ch = 96
        self.linear = nn.Embedding(num_embeddings=1000, embedding_dim=128)
        # latent_dim of each block = 16 (latent chunk) + 128 (class embedding) = 144
        self.g_z = G_z(in_features=16, out_features=4*4*16*ch, eps=1e-4)
        self.gblock = GBlock(in_channels=16*ch, out_channels=16*ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.gblock_1 = GBlock(in_channels=16*ch, out_channels=8*ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.gblock_2 = GBlock(in_channels=8*ch, out_channels=8*ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.gblock_3 = GBlock(in_channels=8*ch, out_channels=4*ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.attention = SelfAttention2d(in_channels=4*ch, c_bar=ch//2, c_hat=2*ch)
        self.gblock_4 = GBlock(in_channels=4*ch, out_channels=2*ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.gblock_5 = GBlock(in_channels=2*ch, out_channels=1*ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.gblock_6 = GBlock(in_channels=ch, out_channels=ch, kernel_size=3, stride=1, padding=1, latent_dim=144, eps=1e-4)
        self.scaledcrossreplicabn = ScaledCrossReplicaBN2d(num_features=ch, eps=1e-4)
        self.conv_2d = spectral_norm(nn.Conv2d(in_channels=ch, out_channels=3, kernel_size=3, stride=1, padding=1), eps=1e-4)

    def forward(self, z, c, truncation=1.0):
        z_gz, z_0, z_1, z_2, z_3, z_4, z_5, z_6 = torch.split(z, split_size_or_sections=16, dim=1)
        cond = self.linear(c)
        x = self.g_z(z_gz)
        x = x.view(z.size(0), 4, 4, -1).permute(0, 3, 1, 2)
        x = self.gblock(x, torch.cat((z_0, cond), dim=1), truncation)
        x = self.gblock_1(x, torch.cat((z_1, cond), dim=1), truncation)
        x = self.gblock_2(x, torch.cat((z_2, cond), dim=1), truncation)
        x = self.gblock_3(x, torch.cat((z_3, cond), dim=1), truncation)
        x = self.attention(x)
        x = self.gblock_4(x, torch.cat((z_4, cond), dim=1), truncation)
        x = self.gblock_5(x, torch.cat((z_5, cond), dim=1), truncation)
        x = self.gblock_6(x, torch.cat((z_6, cond), dim=1), truncation)
        x = self.scaledcrossreplicabn(x, truncation)
        x = torch.relu(x)
        x = self.conv_2d(x)
        x = torch.tanh(x)
        return x
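A hypothetical shape check with randomly initialized weights (no checkpoint needed); train() mode is used so the uninitialized standing statistics, which only the converted weights populate, are never read:

```python
import torch
from src.biggan import BigGAN128

g = BigGAN128().train()
z = torch.randn(1, 120)   # 120-dim latent: six 20-dim chunks
c = torch.tensor([156])   # ImageNet class id
with torch.no_grad():
    img = g(z, c, truncation=0.5)
print(img.shape)  # torch.Size([1, 3, 128, 128])
```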
src/biggan/module.py
@@ -0,0 +1,57 @@
from .spectral_normalization import spectral_norm
from .batch_normalization import HyperBN2d
from .util import depth_to_space

import torch
from torch import nn

class G_z(nn.Module):
    """Input linear layer: maps a latent chunk to the initial 4x4 feature map."""
    def __init__(self, in_features, out_features, eps=1e-4):
        super().__init__()
        self.g_linear = spectral_norm(nn.Linear(in_features=in_features, out_features=out_features), eps=eps)

    def forward(self, x):
        x = self.g_linear(x)
        return x

class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, latent_dim, eps=1e-4):
        super().__init__()
        self.conv0 = spectral_norm(nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        ), eps=eps)
        self.conv1 = spectral_norm(nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        ), eps=eps)
        # 1x1 projection for the residual shortcut.
        self.conv_sc = spectral_norm(nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0
        ), eps=eps)
        self.hyperbn = HyperBN2d(num_features=in_channels, latent_dim=latent_dim, eps=eps)
        self.hyperbn_1 = HyperBN2d(num_features=out_channels, latent_dim=latent_dim, eps=eps)

    def forward(self, x, condition, truncation=1.0):
        # Shortcut branch: 2x upsample (channel duplication + depth-to-space),
        # then a 1x1 projection to out_channels.
        sc = torch.cat((x, x, x, x), dim=1)
        sc = depth_to_space(sc, r=2)
        sc = self.conv_sc(sc)
        # Main branch: conditional BN -> ReLU -> upsample -> conv, twice.
        x = self.hyperbn(x, condition, truncation)
        x = torch.relu(x)
        x = torch.cat((x, x, x, x), dim=1)
        x = depth_to_space(x, r=2)
        x = self.conv0(x)
        x = self.hyperbn_1(x, condition, truncation)
        x = torch.relu(x)
        x = self.conv1(x)
        x = sc + x
        return x
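The channel duplication plus depth_to_space in GBlock.forward amounts to a 2x nearest-neighbor upsample, assuming the repo's depth_to_space helper follows TensorFlow's channel ordering (which is what makes the cat trick work). The same idea in stock PyTorch, where F.pixel_shuffle's different channel ordering calls for repeat_interleave instead of cat:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)
# pixel_shuffle reads channels as (c * r^2 + offset), so duplicating each
# channel r^2 = 4 times in place yields nearest-neighbor upsampling.
up = F.pixel_shuffle(x.repeat_interleave(4, dim=1), upscale_factor=2)
ref = F.interpolate(x, scale_factor=2.0, mode='nearest')
print(torch.allclose(up, ref))  # True
```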