# gan.py
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.parallel

from layers import SVDConv2d


def weights_init(m):
    # DCGAN-style initialisation: conv weights ~ N(0, 0.02),
    # BatchNorm scales ~ N(1, 0.02) with zero bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
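
# Usage sketch (illustrative, not from the original file): the init above is
# meant to be applied recursively via Module.apply, e.g.
#
#     netG = Generator(ngpu=1, nz=nz, ngf=ngf, nc=nc)
#     netG.apply(weights_init)
#
# Caution: the substring test also matches SVDConv2d, so applying it to the
# discriminator assumes SVDConv2d exposes a plain .weight tensor.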

# Generator Params
nz = 100  # size of the latent z vector
ngf = 64  # base number of generator feature maps
nc = 3    # number of image channels

# Discriminator Params
ndf = 64  # base number of discriminator feature maps


class Generator(nn.Module):
    def __init__(self, ngpu, nz, ngf, nc):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            # stock DCGAN 64x64 tail, disabled so the net stops at 32x32:
            # nn.BatchNorm2d(ngf),
            # nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            # nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 32 x 32
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
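
# Sampling sketch (illustrative): latent batches are shaped (N, nz, 1, 1) so
# the first transposed convolution sees a 1x1 spatial map, e.g.
#
#     noise = torch.randn(batch_size, nz, 1, 1)
#     fake = netG(noise)   # -> (batch_size, nc, 32, 32)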


class Discriminator(nn.Module):
    def __init__(self, ngpu, nz, ndf, nc, scale):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.scale = scale
        self.main = nn.Sequential(
            # input is (nc) x 32 x 32 (the 64x64 front end below is disabled)
            # SVDConv2d(nc, ndf, 4, self.scale, 2, 1, bias=False),
            # nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            SVDConv2d(nc, ndf * 2, 4, self.scale, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            SVDConv2d(ndf * 2, ndf * 4, 4, self.scale, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            SVDConv2d(ndf * 4, ndf * 8, 4, self.scale, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            SVDConv2d(ndf * 8, 1, 4, self.scale, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def orth_reg(self):
        # Sum the orthogonality penalty from every SVDConv2d layer.
        reg = 0
        for m in self.modules():
            if isinstance(m, SVDConv2d):
                reg += m.orth_reg()
        return reg

    def D_optimal_reg(self):
        # Sum the spectral (D-optimal) penalty from every SVDConv2d layer.
        reg = 0
        for m in self.modules():
            if isinstance(m, SVDConv2d):
                reg += m.spectral_reg()
        return reg
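
    # Usage sketch (hypothetical training-loop fragment; criterion, the label
    # tensors, and the lambda_* weights are assumptions, not defined here):
    #
    #     errD = criterion(netD(real), real_labels) \
    #          + criterion(netD(fake.detach()), fake_labels) \
    #          + lambda_orth * netD.orth_reg() \
    #          + lambda_spec * netD.D_optimal_reg()
    #     errD.backward()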

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)
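

if __name__ == "__main__":
    # Minimal CPU smoke test (an illustrative sketch, not the original training
    # script). It assumes layers.SVDConv2d takes (in_ch, out_ch, kernel, scale,
    # stride, padding) as used above, and that scale=1.0 is a valid setting.
    netG = Generator(ngpu=1, nz=nz, ngf=ngf, nc=nc)
    netG.apply(weights_init)
    noise = torch.randn(8, nz, 1, 1)
    fake = netG(noise)
    print(fake.shape)  # expected: torch.Size([8, 3, 32, 32])

    netD = Discriminator(ngpu=1, nz=nz, ndf=ndf, nc=nc, scale=1.0)
    scores = netD(fake)
    print(scores.shape)  # expected: torch.Size([8])
    print(netD.orth_reg(), netD.D_optimal_reg())  # regulariser magnitudes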