# chainer_models.py
"""
Source: https://github.com/AkinoriTanaka-phys/DOT/blob/master/model.py
Copyright (c) 2019 AkinoriTanaka-phys
Released under the MIT license
https://github.com/AkinoriTanaka-phys/DOT/blob/master/LICENSE
Copyright (c) 2017 pfnet-research
Released under the MIT license
https://github.com/pfnet-research/chainer-gan-lib/blob/master/LICENSE
"""

import chainer
import chainer.functions as F
import chainer.links as L
import cupy as xp  # GPU array module; not used directly in this file
from snd.sn_linear import SNLinear
from snd.sn_convolution_2d import SNConvolution2D


class DCGANGenerator(chainer.Chain):
    """DCGAN-style generator: a linear projection followed by three
    stride-2 deconvolutions, mapping an n_hidden-dimensional latent
    vector to a 3-channel image (32x32 with the default bottom_width=4).
    """

    def __init__(self, n_hidden=128, bottom_width=4,
                 ch=512, wscale=0.02,
                 hidden_activation=F.relu,
                 output_activation=F.tanh,
                 use_bn=True):
        super(DCGANGenerator, self).__init__()
        self.n_hidden = n_hidden
        self.ch = ch
        self.bottom_width = bottom_width
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.use_bn = use_bn
        with self.init_scope():
            w = chainer.initializers.Normal(wscale)
            self.l0 = L.Linear(self.n_hidden, bottom_width * bottom_width * ch,
                               initialW=w)
            self.dc1 = L.Deconvolution2D(ch, ch // 2, 4, 2, 1, initialW=w)
            self.dc2 = L.Deconvolution2D(ch // 2, ch // 4, 4, 2, 1, initialW=w)
            self.dc3 = L.Deconvolution2D(ch // 4, ch // 8, 4, 2, 1, initialW=w)
            self.dc4 = L.Deconvolution2D(ch // 8, 3, 3, 1, 1, initialW=w)
            if self.use_bn:
                self.bn0 = L.BatchNormalization(
                    bottom_width * bottom_width * ch)
                self.bn1 = L.BatchNormalization(ch // 2)
                self.bn2 = L.BatchNormalization(ch // 4)
                self.bn3 = L.BatchNormalization(ch // 8)

    def __call__(self, z):
        # z: latent batch of shape (N, n_hidden); returns a Variable of
        # shape (N, 3, 8 * bottom_width, 8 * bottom_width).
        if not self.use_bn:
            h = F.reshape(self.hidden_activation(self.l0(z)),
                          (len(z), self.ch,
                           self.bottom_width, self.bottom_width))
            h = self.hidden_activation(self.dc1(h))
            h = self.hidden_activation(self.dc2(h))
            h = self.hidden_activation(self.dc3(h))
            x = self.output_activation(self.dc4(h))
        else:
            h = F.reshape(self.hidden_activation(self.bn0(self.l0(z))),
                          (len(z), self.ch,
                           self.bottom_width, self.bottom_width))
            h = self.hidden_activation(self.bn1(self.dc1(h)))
            h = self.hidden_activation(self.bn2(self.dc2(h)))
            h = self.hidden_activation(self.bn3(self.dc3(h)))
            x = self.output_activation(self.dc4(h))
        return x
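
# Shape check for the generator defaults (n_hidden=128, ch=512, bottom_width=4):
#   z (N, 128) -> l0/bn0 -> reshape (N, 512, 4, 4)
#   dc1 -> (N, 256, 8, 8), dc2 -> (N, 128, 16, 16), dc3 -> (N, 64, 32, 32)
#   dc4 + tanh -> (N, 3, 32, 32)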


class WGANDiscriminator(chainer.Chain):
    """Convolutional critic for WGAN-style training: seven convolutions
    (three of them stride-2) followed by a linear head that outputs an
    unbounded scalar score per image.
    """

    def __init__(self, bottom_width=4, ch=512, wscale=0.02, output_dim=1):
        w = chainer.initializers.Normal(wscale)
        super(WGANDiscriminator, self).__init__()
        with self.init_scope():
            self.c0 = L.Convolution2D(3, ch // 8, 3, 1, 1, initialW=w)
            self.c1 = L.Convolution2D(ch // 8, ch // 4, 4, 2, 1, initialW=w)
            self.c1_0 = L.Convolution2D(ch // 4, ch // 4, 3, 1, 1, initialW=w)
            self.c2 = L.Convolution2D(ch // 4, ch // 2, 4, 2, 1, initialW=w)
            self.c2_0 = L.Convolution2D(ch // 2, ch // 2, 3, 1, 1, initialW=w)
            self.c3 = L.Convolution2D(ch // 2, ch // 1, 4, 2, 1, initialW=w)
            self.c3_0 = L.Convolution2D(ch // 1, ch // 1, 3, 1, 1, initialW=w)
            self.l4 = L.Linear(
                bottom_width * bottom_width * ch, output_dim, initialW=w)

    def __call__(self, x):
        # x: image batch of shape (N, 3, H, W); intermediate activations are
        # kept as attributes so downstream code can reuse them.
        self.x = x
        self.h0 = F.leaky_relu(self.c0(self.x))
        self.h1 = F.leaky_relu(self.c1(self.h0))
        self.h2 = F.leaky_relu(self.c1_0(self.h1))
        self.h3 = F.leaky_relu(self.c2(self.h2))
        self.h4 = F.leaky_relu(self.c2_0(self.h3))
        self.h5 = F.leaky_relu(self.c3(self.h4))
        self.h6 = F.leaky_relu(self.c3_0(self.h5))
        return self.l4(self.h6)
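
# With 32x32 inputs, the three stride-2 convolutions reduce the feature map to
# 4x4, matching the default bottom_width=4 assumed by the final linear layer
# (bottom_width * bottom_width * ch input units).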


class SNDCGANDiscriminator(chainer.Chain):
    """DCGAN-style discriminator with spectral normalization on every
    layer (SNConvolution2D / SNLinear); same layer topology as
    WGANDiscriminator above, but without stored intermediate activations.
    """

    def __init__(self, bottom_width=4, ch=512, wscale=0.02, output_dim=1):
        w = chainer.initializers.Normal(wscale)
        self.bottom_width = bottom_width
        super(SNDCGANDiscriminator, self).__init__()
        with self.init_scope():
            self.c0_0 = SNConvolution2D(3, ch // 8, 3, 1, 1, initialW=w)
            self.c0_1 = SNConvolution2D(ch // 8, ch // 4, 4, 2, 1, initialW=w)
            self.c1_0 = SNConvolution2D(ch // 4, ch // 4, 3, 1, 1, initialW=w)
            self.c1_1 = SNConvolution2D(ch // 4, ch // 2, 4, 2, 1, initialW=w)
            self.c2_0 = SNConvolution2D(ch // 2, ch // 2, 3, 1, 1, initialW=w)
            self.c2_1 = SNConvolution2D(ch // 2, ch // 1, 4, 2, 1, initialW=w)
            self.c3_0 = SNConvolution2D(ch // 1, ch // 1, 3, 1, 1, initialW=w)
            self.l4 = SNLinear(
                bottom_width * bottom_width * ch, output_dim, initialW=w)

    def __call__(self, x):
        h = F.leaky_relu(self.c0_0(x))
        h = F.leaky_relu(self.c0_1(h))
        h = F.leaky_relu(self.c1_0(h))
        h = F.leaky_relu(self.c1_1(h))
        h = F.leaky_relu(self.c2_0(h))
        h = F.leaky_relu(self.c2_1(h))
        h = F.leaky_relu(self.c3_0(h))
        return self.l4(h)
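

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the upstream DOT
    # code): build each model and run one forward pass on random CPU data to
    # confirm the shapes noted above. Assumes NumPy and Chainer are installed;
    # the spectrally normalized discriminator additionally needs the local
    # `snd` package imported at the top of this file.
    import numpy as np

    batch = 2
    gen = DCGANGenerator(n_hidden=128)
    z = np.random.randn(batch, 128).astype(np.float32)
    with chainer.using_config("train", False):
        x = gen(z)                            # (2, 3, 32, 32)
        score_w = WGANDiscriminator()(x)      # (2, 1) critic scores
        score_sn = SNDCGANDiscriminator()(x)  # (2, 1) logits
    print(x.shape, score_w.shape, score_sn.shape)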