diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f7eafe5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +.vscode \ No newline at end of file diff --git a/README.md b/README.md index bdfe86b..28aa13a 100644 --- a/README.md +++ b/README.md @@ -1 +1,36 @@ -# BigGAN-Generator-Pretrained \ No newline at end of file +# BigGAN Generators with Pretrained Weights in PyTorch +PyTorch implementation of the generator from "Large Scale GAN Training for High Fidelity Natural Image Synthesis" (BigGAN). + +# Download Pretrained Weights + + +# Demo +```shell +python demo.py -w PRETRAINED_WEIGHT [-s IMAGE_SIZE] [-c CLASS_LABEL] [-t TRUNCATION] +python demo.py -w ./biggan512-release.pt -s 512 -t 0.3 -c 156 +python demo.py -w ./biggan256-release.pt -s 256 -t 0.02 -c 11 +python demo.py --pretrained_weight ./biggan128-release.pt --size 128 --truncation 0.2 --class_label 821 +``` + +# Results +|![alt text](./assets/p1.png)| +|:--:| +|*class 156 (512 x 512)*| +|![alt text](./assets/p2.png)| +|*class 11 (512 x 512)*| +|![alt text](./assets/p3.png)| +|*class 821 (512 x 512)*| + +# Dependencies +Please refer to the environment.yml file. 
+ +# Pretrained Weights +The pretrained weights are converted from the tensorflow hub modules: +- https://tfhub.dev/deepmind/biggan-128/2 +- https://tfhub.dev/deepmind/biggan-256/2 +- https://tfhub.dev/deepmind/biggan-512/2 + + +# References +paper: https://arxiv.org/abs/1809.11096 + diff --git a/assets/p1.png b/assets/p1.png new file mode 100644 index 0000000..fb74591 Binary files /dev/null and b/assets/p1.png differ diff --git a/assets/p2.png b/assets/p2.png new file mode 100644 index 0000000..be72053 Binary files /dev/null and b/assets/p2.png differ diff --git a/assets/p3.png b/assets/p3.png new file mode 100644 index 0000000..2225f92 Binary files /dev/null and b/assets/p3.png differ diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..939c68d --- /dev/null +++ b/demo.py @@ -0,0 +1,41 @@ +from src.biggan import BigGAN128 +from src.biggan import BigGAN256 +from src.biggan import BigGAN512 + +import torch +import torchvision + +from scipy.stats import truncnorm + +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--truncation', type=float, default=0.4) + parser.add_argument('-s', '--size', type=int, choices=[128, 256, 512], default=512) + parser.add_argument('-c', '--class_label', type=int, choices=range(0, 1000), default=156) + parser.add_argument('-w', '--pretrained_weight', type=str, required=True) + args = parser.parse_args() + + truncation = torch.clamp(torch.tensor(args.truncation), min=0.02+1e-4, max=1.0-1e-4).float() + c = torch.tensor((args.class_label,)).long() + + if args.size == 128: + z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 120))).float() + biggan = BigGAN128() + elif args.size == 256: + z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 140))).float() + biggan = BigGAN256() + elif args.size == 512: + z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 128))).float() + biggan = BigGAN512() + + 
class SelfAttention2d(nn.Module):
    """Self-attention layer over the spatial positions of a 2-D feature map.

    Queries (``theta``) are computed at full resolution; keys (``phi``) and
    values (``g``) are 2x2 max-pooled first.  The attended values go through
    ``o_conv`` and are added to the input scaled by the learnable scalar
    ``gamma``, which is initialised to zero so the layer starts as identity.

    Args:
        in_channels: channels of the input (and output) feature map.
        c_bar: channels of the query/key projections.
        c_hat: channels of the value projection.
        eps: epsilon forwarded to ``spectral_norm`` for every 1x1 convolution.
    """

    def __init__(self, in_channels, c_bar, c_hat, eps=1e-4):
        super().__init__()
        self.theta = spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=c_bar, kernel_size=1, bias=False), eps=eps)
        self.phi = spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=c_bar, kernel_size=1, bias=False), eps=eps)
        self.g = spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=c_hat, kernel_size=1, bias=False), eps=eps)
        self.o_conv = spectral_norm(nn.Conv2d(in_channels=c_hat, out_channels=in_channels, kernel_size=1, bias=False), eps=eps)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x``; the output has the shape of ``x``."""
        n, _, h, w = x.size()

        # Values and keys live on the pooled grid.  Its size is taken from the
        # pooled tensors themselves: max_pool2d floors each spatial dimension
        # separately, so `h * w // 4` is only right when h and w are even.
        g_x = F.max_pool2d(self.g(x), kernel_size=2)
        g_x = g_x.reshape(n, g_x.size(1), -1)            # (n, c_hat, pooled)
        phi_x = F.max_pool2d(self.phi(x), kernel_size=2)
        phi_x = phi_x.reshape(n, phi_x.size(1), -1)      # (n, c_bar, pooled)
        theta_x = self.theta(x).view(n, -1, h * w)       # (n, c_bar, h*w)

        # Each full-resolution position attends over all pooled positions.
        attn = F.softmax(torch.bmm(theta_x.transpose(1, 2), phi_x), dim=-1)
        y = torch.bmm(g_x, attn.transpose(1, 2)).view(n, -1, h, w)
        return self.gamma * self.o_conv(y) + x
class CrossReplicaBN2d(nn.BatchNorm2d):
    """BatchNorm2d that uses truncation-dependent "standing" statistics in eval mode.

    In addition to the usual running statistics, the layer keeps two buffers
    of shape ``(50, num_features)``: ``standing_means`` and ``standing_vars``.
    Row ``i`` holds the statistics for truncation ``(i + 1) * 0.02``, so the
    rows cover truncations 0.02, 0.04, ..., 1.0.

    Args:
        num_features, eps, momentum, affine, track_running_stats: forwarded
            unchanged to ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        # Deterministic placeholders (zero mean, unit variance) instead of
        # uninitialised memory; a loaded state dict overwrites both buffers.
        self.register_buffer('standing_means', torch.zeros(50, num_features))
        self.register_buffer('standing_vars', torch.ones(50, num_features))

    # NOTE: the former `torch._jit_internal.weak_script_method` decorator was
    # only a TorchScript hint and is absent from current PyTorch releases,
    # where referencing it fails at class-definition time.
    def forward(self, input, truncation=1.0):
        """Normalise ``input``.

        Args:
            input: 4-D tensor (validated by ``_check_input_dim``).
            truncation: truncation level in [0.02, 1.0]; only used in eval
                mode, where it selects the standing statistics.

        Returns:
            The normalised tensor, same shape as ``input``.
        """
        self._check_input_dim(input)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative moving average.
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        if self.training:
            mean, var = self.running_mean, self.running_var
        else:
            mean = self.get_standing_stats(self.standing_means, truncation)
            var = self.get_standing_stats(self.standing_vars, truncation)
        return F.batch_norm(
            input, mean, var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps
        )

    def get_standing_stats(self, stack, truncation):
        """Return the statistics row for ``truncation``, linearly interpolated.

        Args:
            stack: ``(rows, num_features)`` tensor whose row ``i`` belongs to
                truncation ``(i + 1) * 0.02``.
            truncation: value in [0.02, 1.0].

        Returns:
            A new ``(num_features,)`` tensor.

        Raises:
            ValueError: if ``truncation`` is outside [0.02, 1.0].
        """
        step = 0.02
        lowest = step - 1e-12
        highest = 1.0 + 1e-12
        if not (lowest <= truncation <= highest):
            raise ValueError(
                'truncation must be in [0.02, 1.0], got {}'.format(truncation)
            )

        last = stack.size(0) - 1
        # Fractional row index, clamped so float noise cannot leave the table.
        position = min(max(truncation / step - 1.0, 0.0), float(last))
        lower = int(position)
        # Two-decimal rounding snaps values that are on the grid up to float
        # error (e.g. 0.38 / 0.02 == 18.999...) onto their exact row.
        frac = round(position - lower, 2)
        if frac >= 1.0:
            lower, frac = lower + 1, 0.0
        if frac == 0.0 or lower >= last:
            return stack[min(lower, last)].clone()
        return (1.0 - frac) * stack[lower] + frac * stack[lower + 1]
class HyperBN2d(nn.Module):
    """Conditional batch norm whose gain and bias are projected from a condition vector.

    The input is normalised by a non-affine ``CrossReplicaBN2d``; the
    per-channel gain (offset by one) and bias come from two bias-free,
    spectrally normalised linear maps of ``condition``.
    """

    def __init__(self, num_features, latent_dim, eps=1e-4):
        super().__init__()
        self.crossreplicabn = CrossReplicaBN2d(num_features=num_features, affine=False, eps=eps)
        gain_proj = nn.Linear(in_features=latent_dim, out_features=num_features, bias=False)
        self.gamma = spectral_norm(gain_proj, eps=eps)
        bias_proj = nn.Linear(in_features=latent_dim, out_features=num_features, bias=False)
        self.beta = spectral_norm(bias_proj, eps=eps)

    def forward(self, x, condition, truncation=1.0):
        """Return ``(gamma(condition) + 1) * bn(x, truncation) + beta(condition)``."""
        batch = condition.size(0)
        # Reshape to (batch, channels, 1, 1) so the terms broadcast over H and W.
        gain = self.gamma(condition).view(batch, -1, 1, 1) + 1
        normalised = self.crossreplicabn(x, truncation)
        bias = self.beta(condition).view(batch, -1, 1, 1)
        return gain * normalised + bias
class BigGAN256(nn.Module):
    """BigGAN generator with six upsampling ``GBlock`` stages.

    ``forward`` expects ``z`` to split into exactly seven chunks of width 20
    along dim 1 and ``c`` to hold integer class indices in [0, 1000).
    """

    def __init__(self):
        super().__init__()
        ch = 96
        # Settings shared by every GBlock; 148 = 20 latent + 128 class-embedding dims.
        shared = dict(kernel_size=3, stride=1, padding=1, latent_dim=148, eps=1e-4)
        self.linear = nn.Embedding(num_embeddings=1000, embedding_dim=128)
        self.g_z = G_z(in_features=20, out_features=4*4*16*ch, eps=1e-4)
        self.gblock = GBlock(in_channels=16*ch, out_channels=16*ch, **shared)
        self.gblock_1 = GBlock(in_channels=16*ch, out_channels=8*ch, **shared)
        self.gblock_2 = GBlock(in_channels=8*ch, out_channels=8*ch, **shared)
        self.gblock_3 = GBlock(in_channels=8*ch, out_channels=4*ch, **shared)
        self.gblock_4 = GBlock(in_channels=4*ch, out_channels=2*ch, **shared)
        self.attention = SelfAttention2d(in_channels=2*ch, c_bar=ch//4, c_hat=ch)
        self.gblock_5 = GBlock(in_channels=2*ch, out_channels=1*ch, **shared)
        self.scaledcrossreplicabn = ScaledCrossReplicaBN2d(num_features=ch, eps=1e-4)
        self.conv_2d = spectral_norm(
            nn.Conv2d(in_channels=ch, out_channels=3, kernel_size=3, stride=1, padding=1),
            eps=1e-4,
        )

    def forward(self, z, c, truncation=1.0):
        """Generate images from latents ``z`` and class indices ``c``.

        Returns a tensor with 3 channels and values in (-1, 1) (tanh output).
        """
        z_gz, z_0, z_1, z_2, z_3, z_4, z_5 = torch.split(z, 20, dim=1)
        cond = self.linear(c)

        # Project the first latent chunk to a 4x4 map, then go channels-first.
        x = self.g_z(z_gz).view(z.size(0), 4, 4, -1).permute(0, 3, 1, 2)

        before_attention = (
            (self.gblock, z_0),
            (self.gblock_1, z_1),
            (self.gblock_2, z_2),
            (self.gblock_3, z_3),
            (self.gblock_4, z_4),
        )
        for block, chunk in before_attention:
            x = block(x, torch.cat((chunk, cond), dim=1), truncation)

        x = self.attention(x)
        x = self.gblock_5(x, torch.cat((z_5, cond), dim=1), truncation)
        x = self.scaledcrossreplicabn(x, truncation)
        return torch.tanh(self.conv_2d(torch.relu(x)))
class G_z(nn.Module):
    """Spectrally normalised linear projection applied to the first latent chunk."""

    def __init__(self, in_features, out_features, eps=1e-4):
        super().__init__()
        projection = nn.Linear(in_features=in_features, out_features=out_features)
        self.g_linear = spectral_norm(projection, eps=eps)

    def forward(self, x):
        """Return ``g_linear(x)``."""
        return self.g_linear(x)
class SpectralNorm(object):
    """Forward-pre-hook that divides a weight by an estimate of its spectral norm.

    The estimate ``sigma = u^T W v`` uses the vectors ``u`` and ``v`` stored
    on the module as buffers ``<name>_u`` / ``<name>_v``; they are refined by
    power iteration on every training-mode forward.  The raw weight is kept
    as parameter ``<name>_orig`` and the normalised one is assigned to
    ``<name>`` just before each forward call.
    """
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version = 1
    # At version 1:
    #   made  `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        """Store the hook configuration.

        Raises:
            ValueError: if ``n_power_iterations`` is not positive.
        """
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Flatten ``weight`` into a 2-D matrix with one row per output unit.

        4-D weights are permuted to ``(0, 2, 3, 1)`` before flattening and
        2-D weights are returned unchanged; ``self.dim`` is only consulted
        for every other rank.
        """
        if weight.dim() == 4:
            return weight.permute(0, 2, 3, 1).reshape(weight.size(0), -1)
        if weight.dim() == 2:
            return weight
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module, do_power_iteration):
        """Return ``<name>_orig / sigma``, optionally refining ``u`` and ``v`` first."""
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #    However, after we update `u` and `v` in-place, we need to **clone**
        #    them before using them to normalize the weight. This is to support
        #    backproping through two forward passes, e.g., the common pattern in
        #    GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #    complain that variables needed to do backward for the first forward
        #    (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)
        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                # See above on why we need to clone.  (__init__ guarantees
                # n_power_iterations > 0, so the loop always ran.)
                u = u.clone()
                v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        """Undo the re-parameterisation, leaving the normalised weight as ``<name>``."""
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        """Forward-pre-hook entry point: refresh ``<name>`` before the forward pass."""
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        """Return ``v`` such that ``u @ W @ v == target_sigma``."""
        # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        # Plain mm/mv calls replace the deprecated `torch.chain_matmul`;
        # applying W^T to `u` first keeps the work to matrix-vector products.
        gram_pinv = weight_mat.t().mm(weight_mat).pinverse()
        v = gram_pinv.mv(weight_mat.t().mv(u))
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install a ``SpectralNorm`` hook for ``module.<name>`` and return it.

        Raises:
            RuntimeError: if a spectral-norm hook for ``name`` already exists.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \\
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is 0, except for modules that are instances of
            ConvTranspose1/2/3d, when it is 1

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        # Transposed convolutions keep their output channels on dim 1.
        if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
def normalize(input, p=2, dim=1, eps=1e-12, out=None):
    """Divide ``input`` by ``(its p-norm along dim) + eps``.

    Unlike ``torch.nn.functional.normalize`` the epsilon is *added* to the
    norm rather than used as a lower clamp.

    Args:
        input: tensor to normalise.
        p: order of the norm.
        dim: dimension along which the norm is taken.
        eps: value added to the norm to avoid division by zero.
        out: optional tensor that receives the result in-place.

    Returns:
        The normalised tensor (``out`` itself when it is given).
    """
    # The former `torch._jit_internal.weak_script` decorator and the
    # `torch.jit._unwrap_optional` call were TorchScript-only helpers that are
    # missing from current PyTorch; in eager mode `out` can be used directly.
    denom = input.norm(p, dim, True).expand_as(input) + eps
    if out is None:
        return input / denom
    return torch.div(input, denom, out=out)