updater.py

import numpy as np

from chainer import training, reporter, cuda
from chainer import Variable


class WassersteinGANUpdater(training.StandardUpdater):

    def __init__(self, *, iterator, noise_iterator, optimizer_generator,
                 optimizer_critic, device=-1):

        if optimizer_generator.target.name is None:
            optimizer_generator.target.name = 'generator'

        if optimizer_critic.target.name is None:
            optimizer_critic.target.name = 'critic'

        iterators = {'main': iterator, 'z': noise_iterator}
        optimizers = {'generator': optimizer_generator,
                      'critic': optimizer_critic}

        super().__init__(iterators, optimizers, device=device)

        if device >= 0:
            cuda.get_device(device).use()
            [optimizer.target.to_gpu() for optimizer in optimizers.values()]

        self.xp = cuda.cupy if device >= 0 else np

    @property
    def optimizer_generator(self):
        return self._optimizers['generator']

    @property
    def optimizer_critic(self):
        return self._optimizers['critic']

    @property
    def generator(self):
        return self._optimizers['generator'].target

    @property
    def critic(self):
        return self._optimizers['critic'].target

    @property
    def x(self):
        return self._iterators['main']

    @property
    def z(self):
        return self._iterators['z']

    def next_batch(self, iterator):
        batch = self.converter(iterator.next(), self.device)
        return Variable(batch)

    def sample(self):
        """Return a sample batch of images."""
        z = self.next_batch(self.z)
        x = self.generator(z, test=True)

        # [-1, 1] -> [0, 1]
        x += 1.0
        x /= 2

        return x
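
    # Note on the gradient handling in update_core: the critic outputs are
    # never reduced to a scalar loss. Instead each output's gradient is set
    # directly to +1 or -1 before backward(), which is equivalent to taking
    # the summed critic score (with that sign) as the loss. The opposite signs
    # on the real and fake batches realize the Wasserstein critic objective,
    # and the +1 gradient in the generator step drives the generator to lower
    # the critic's score on its samples. critic.clamp() performs the weight
    # clipping that keeps the critic approximately Lipschitz.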

    def update_core(self):

        def _update(optimizer, loss):
            optimizer.target.cleargrads()
            loss.backward()
            optimizer.update()

        # Update critic 5 times
        for _ in range(5):
            # Clamp critic parameters
            self.critic.clamp()

            # Real images
            x_real = self.next_batch(self.x)
            y_real = self.critic(x_real)
            y_real.grad = self.xp.ones_like(y_real.data)
            _update(self.optimizer_critic, y_real)

            # Fake images
            z = self.next_batch(self.z)
            x_fake = self.generator(z)
            y_fake = self.critic(x_fake)
            y_fake.grad = -1 * self.xp.ones_like(y_fake.data)
            _update(self.optimizer_critic, y_fake)

            reporter.report({
                'critic/loss/real': y_real,
                'critic/loss/fake': y_fake,
                'critic/loss': y_real - y_fake
            })

        # Update generator 1 time
        z = self.next_batch(self.z)
        x_fake = self.generator(z)
        y_fake = self.critic(x_fake)
        y_fake.grad = self.xp.ones_like(y_fake.data)
        _update(self.optimizer_generator, y_fake)

        reporter.report({'generator/loss': y_fake})
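

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original file): one way
# this updater could be wired into a chainer Trainer. `Generator`, `Critic`
# (expected to provide the clamp() method used above) and RandomNoiseIterator
# are hypothetical placeholders, as are the dataset, batch size and learning
# rate; updater.sample() could likewise be called from a custom extension to
# dump generated images during training.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import chainer
    from chainer import datasets, optimizers, training
    from chainer.training import extensions

    from models import Generator, Critic        # hypothetical module
    from noise import RandomNoiseIterator       # hypothetical module

    batch_size = 64

    # Real MNIST images in [0, 1]; rescaling to the generator's [-1, 1] range
    # is omitted in this sketch.
    train, _ = datasets.get_mnist(withlabel=False, ndim=3)
    x_iter = chainer.iterators.SerialIterator(train, batch_size)
    z_iter = RandomNoiseIterator(batch_size)     # hypothetical signature

    optimizer_generator = optimizers.RMSprop(lr=5e-5)
    optimizer_generator.setup(Generator())
    optimizer_critic = optimizers.RMSprop(lr=5e-5)
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=x_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=-1)

    trainer = training.Trainer(updater, (20, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'critic/loss', 'generator/loss']))
    trainer.run()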