train.py
import objax
import generator as g
import discriminator as d
import jax.numpy as jnp
from objax.functional import sigmoid
from PIL import Image
import numpy as np
import util as u
import wandb
import data
import argparse
import os
import time
import sys
JIT = True
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--manifest-file', type=str)
parser.add_argument('--group', type=str, default='dft')
parser.add_argument('--batch-size', type=int)
parser.add_argument('--gradient-clip', type=float, default=1.0)
parser.add_argument('--epochs', type=int, default=10000)
parser.add_argument('--max-run-time', type=int, default=None,
                    help='max run time in secs')
parser.add_argument('--steps-per-epoch', type=int)
parser.add_argument('--patch-size', type=int, default=64)
parser.add_argument('--positive-weight', type=float, default=1.0)
parser.add_argument('--reconstruction-loss-weight', type=float, default=1.0)
parser.add_argument('--change-loss-weight', type=float, default=0.0)
parser.add_argument('--discriminator-loss-weight', type=float, default=1.0)
parser.add_argument('--generator-sigmoid-b', type=float, default=1.0)
parser.add_argument('--discriminator-weight-clip', type=float, default=0.1)
parser.add_argument('--generator-learning-rate', type=float, default=1e-3)
parser.add_argument('--discriminator-learning-rate', type=float, default=1e-4)
opts = parser.parse_args()
print(opts)
finish_time = None
if opts.max_run_time is not None and opts.max_run_time > 0:
    finish_time = time.time() + opts.max_run_time
RUN = u.DTS()
print(">RUN", RUN)
sys.stdout.flush()
wandb.init(project='dither_net', group=opts.group, name=RUN)
wandb.config.gradient_clip = opts.gradient_clip
wandb.config.positive_weight = opts.positive_weight
wandb.config.reconstruction_loss_weight = opts.reconstruction_loss_weight
wandb.config.change_loss_weight = opts.change_loss_weight
wandb.config.discriminator_loss_weight = opts.discriminator_loss_weight
wandb.config.generator_sigmoid_b = opts.generator_sigmoid_b
wandb.config.discriminator_weight_clip = opts.discriminator_weight_clip
wandb.config.generator_learning_rate = opts.generator_learning_rate
wandb.config.discriminator_learning_rate = opts.discriminator_learning_rate
generator = g.Generator()
discriminator = d.Discriminator()
print("generator", generator.vars())
print("discriminator", discriminator.vars())
sys.stdout.flush()


def steep_sigmoid(x):
    # since D only ever sees values in (0, 1) from true dithers, we want to
    # make its job harder for fake dithers by squashing the sigmoid activation
    # towards 0 & 1. the simplest way to do this is by varying B in the
    # generalised logistic function. this is handy since it stays a function
    # of x alone, so we can continue to use the numerically stable expit that
    # objax's sigmoid wraps.
    # see https://en.wikipedia.org/wiki/Generalised_logistic_function
    return sigmoid(opts.generator_sigmoid_b * x)
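
# for a sense of the squashing: sigmoid(1.0) is roughly 0.731, while with
# e.g. --generator-sigmoid-b 4.0 we get sigmoid(4 * 1.0), roughly 0.982, so
# predictions are pushed much closer to the {0, 1} values the discriminator
# sees from true dithers. (4.0 is just an illustrative value here, not a
# recommended setting.)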


def generator_loss(rgb_img_t1, true_dither_t0, true_dither_t1):
    # generator loss is based on the dither generated from the t1 RGB frame
    # (conditioned on the t0 true dither)
    pred_dither_t1 = steep_sigmoid(generator(rgb_img_t1, true_dither_t0))

    # it has three components;
    # 1) a comparison to the t1 true_dither to see how well it reconstructs it
    per_pixel_reconstruction_loss = jnp.abs(pred_dither_t1 - true_dither_t1)
    loss_weight = jnp.where(true_dither_t1 == 1, opts.positive_weight, 1.0)
    reconstruction_loss = jnp.mean(loss_weight * per_pixel_reconstruction_loss)

    # 2) a comparison to the t0 true_dither to see how much has changed
    per_pixel_change_loss = jnp.abs(pred_dither_t1 - true_dither_t0)
    loss_weight = jnp.where(true_dither_t0 == 1, opts.positive_weight, 1.0)
    change_loss = jnp.mean(loss_weight * per_pixel_change_loss)

    # 3) how well it fools the discriminator
    discriminator_logits = discriminator(pred_dither_t1, training=False)
    overall_patch_loss = -jnp.mean(discriminator_logits)

    # overall loss is a weighted combination of the three
    overall_loss = (reconstruction_loss * opts.reconstruction_loss_weight +
                    change_loss * opts.change_loss_weight +
                    overall_patch_loss * opts.discriminator_loss_weight)

    return (overall_loss,
            {'scaled': {'reconstruction_loss':
                        reconstruction_loss * opts.reconstruction_loss_weight,
                        'change_loss':
                        change_loss * opts.change_loss_weight,
                        'overall_patch_loss':
                        overall_patch_loss * opts.discriminator_loss_weight},
             'unscaled': {'reconstruction_loss': reconstruction_loss,
                          'change_loss': change_loss,
                          'overall_patch_loss': overall_patch_loss}})


def discriminator_loss(rgb_img_t1, true_dither_t0, true_dither_t1):
    # discriminator loss is based on the discriminator's ability to
    # distinguish (smoothed) true dithers ...
    smoothed_true_dither_t1 = (true_dither_t1 * 0.8) + 0.1
    discriminator_logits = discriminator(
        smoothed_true_dither_t1, training=True)
    true_dither_loss = jnp.mean(discriminator_logits)

    # ... from fake dithers
    fake_dither_t1 = steep_sigmoid(generator(rgb_img_t1, true_dither_t0))
    discriminator_logits = discriminator(fake_dither_t1, training=True)
    fake_dither_loss = jnp.mean(discriminator_logits)

    # overall loss is the difference between the two
    overall_loss = fake_dither_loss - true_dither_loss
    return (overall_loss,
            {'fake_dither_loss': fake_dither_loss,
             'true_dither_loss': true_dither_loss})
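
# note: this fake-minus-true objective, combined with the discriminator weight
# clipping applied in the training loop below, appears to follow the
# WGAN-style critic formulation (hence the lipschitz-constraint comment there).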


def build_train_step_fn(model, loss_fn):
    gradient_loss = objax.GradValues(loss_fn, model.vars())
    optimiser = objax.optimizer.Adam(model.vars())

    def train_step(learning_rate, rgb_img_t1, true_dither_t0, true_dither_t1):
        grads, _loss = gradient_loss(
            rgb_img_t1, true_dither_t0, true_dither_t1)
        grads = u.clip_gradients(grads, theta=opts.gradient_clip)
        optimiser(learning_rate, grads)
        grad_norms = [jnp.linalg.norm(g) for g in grads]
        return grad_norms

    if JIT:
        train_step = objax.Jit(
            train_step, gradient_loss.vars() + optimiser.vars())
    return train_step
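
# note: the learning rate is passed into train_step as an argument rather than
# baked into the optimiser, so each jitted step can be called with its own
# rate. u.clip_gradients is assumed here to rescale gradients whose norm
# exceeds theta; see util.py for its actual behaviour.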


generator_train_step = build_train_step_fn(
    generator, generator_loss)
discriminator_train_step = build_train_step_fn(
    discriminator, discriminator_loss)

# load some full res images for checking model performance during training.
# the first three are sequential frames at a scene change.
# the second three are a mix of lighting conditions; a mix, a light & a dark
full_rgbs_t1 = []
full_dithers_t0 = []
full_dithers_t1 = []
for frame in [55290, 55291, 55292, 67000, 77000, 90000]:
    _full_rgb, full_true_dither = data.parse_full_size(
        "frames/full_res/f_%08d.jpg" % (frame-1))
    full_dithers_t0.append(full_true_dither)
    full_rgb, full_true_dither = data.parse_full_size(
        "frames/full_res/f_%08d.jpg" % frame)
    full_rgbs_t1.append(full_rgb)
    full_dithers_t1.append(full_true_dither)
full_rgbs_t1 = np.stack(full_rgbs_t1)
full_dithers_t0 = np.stack(full_dithers_t0)
full_dithers_t1 = np.stack(full_dithers_t1)
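
# note: each evaluation example pairs frame f's RGB with frame f-1's true
# dither, matching how the generator is conditioned in the loss functions
# above (rgb_img_t1 plus true_dither_t0).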

# jit the generator now (we'll use it for predicting against the full res
# images) and also the two loss fns
if JIT:
    generator = objax.Jit(generator)
    generator_loss = objax.Jit(generator_loss, generator.vars())
    discriminator_loss = objax.Jit(discriminator_loss, discriminator.vars())

# setup output directory for full res samples
u.ensure_dir_exists("full_res_samples/%s" % RUN)
if os.path.exists("full_res_samples/latest"):
    os.remove("full_res_samples/latest")
os.symlink(RUN, "full_res_samples/latest")

# init dataset iterator
dataset = data.dataset(manifest_file=opts.manifest_file,
                       batch_size=opts.batch_size,
                       patch_size=opts.patch_size)

# set up ckpting for G and D
generator_ckpt = objax.io.Checkpoint(
    logdir=f"ckpts/{RUN}/generator/", keep_ckpts=20)
discriminator_ckpt = objax.io.Checkpoint(
    logdir=f"ckpts/{RUN}/discriminator/", keep_ckpts=20)

# run training loop!
for epoch in range(opts.epochs):
    generator_grads_min_max = None
    discriminator_grads_min_max = None

    # run some number of steps, alternating between training G and D
    train_generator = True
    for (rgb_imgs_t1, true_dithers_t0,
         true_dithers_t1) in dataset.take(opts.steps_per_epoch):
        rgb_imgs_t1 = rgb_imgs_t1.numpy()
        true_dithers_t0 = true_dithers_t0.numpy()
        true_dithers_t1 = true_dithers_t1.numpy()
        if train_generator:
            grad_norms = generator_train_step(
                opts.generator_learning_rate, rgb_imgs_t1, true_dithers_t0,
                true_dithers_t1)
            if generator_grads_min_max is None:
                generator_grads_min_max = (float(jnp.min(grad_norms)),
                                           float(jnp.max(grad_norms)))
        else:
            grad_norms = discriminator_train_step(
                opts.discriminator_learning_rate, rgb_imgs_t1, true_dithers_t0,
                true_dithers_t1)
            if discriminator_grads_min_max is None:
                discriminator_grads_min_max = (float(jnp.min(grad_norms)),
                                               float(jnp.max(grad_norms)))

            # clip D weights. urgh; this is the hacky way to enforce the
            # lipschitz constraint; much better to get the gradient penalty
            # working instead.
            for v in discriminator.vars().values():
                v.assign(jnp.clip(v.value,
                                  -opts.discriminator_weight_clip,
                                  opts.discriminator_weight_clip))

        train_generator = not train_generator

    # ckpt models
    generator_ckpt.save(generator.vars(), idx=epoch)
    discriminator_ckpt.save(discriminator.vars(), idx=epoch)

    # check loss against last batch
    overall_loss, component_losses = generator_loss(
        rgb_imgs_t1, true_dithers_t0, true_dithers_t1)
    generator_losses = {  # clumsy o_O; treemap the float cast?
        'overall_loss': float(overall_loss),
        'scaled': {
            'change_loss':
                float(component_losses['scaled']['change_loss']),
            'reconstruction_loss':
                float(component_losses['scaled']['reconstruction_loss']),
            'overall_patch_loss':
                float(component_losses['scaled']['overall_patch_loss'])
        },
        'unscaled': {
            'change_loss':
                float(component_losses['unscaled']['change_loss']),
            'reconstruction_loss':
                float(component_losses['unscaled']['reconstruction_loss']),
            'overall_patch_loss':
                float(component_losses['unscaled']['overall_patch_loss'])
        }
    }

    overall_loss, component_losses = discriminator_loss(
        rgb_imgs_t1, true_dithers_t0, true_dithers_t1)
    discriminator_losses = {
        'overall_loss': float(overall_loss),
        'fake_dither_loss': float(component_losses['fake_dither_loss']),
        'true_dither_loss': float(component_losses['true_dither_loss'])
    }

    # save full res pred dithers in a collage.
    full_pred_dithers = generator(full_rgbs_t1, full_dithers_t0)
    samples = [u.dither_to_pil(p) for p in full_pred_dithers]
    collage = u.collage(samples)
    collage.save("full_res_samples/%s/%05d.png" % (RUN, epoch))
    collage.save("full_res_samples/last/%s.png" % RUN)

    # sanity check for collapse to all white or all black
    num_sample_white_pixels = int(jnp.sum(full_pred_dithers > 0))
    num_sample_black_pixels = int(jnp.sum(full_pred_dithers < 0))

    # some wandb logging
    wandb.log({
        'gen_overall_loss': generator_losses['overall_loss'],
        'gen_scaled_reconstruction_loss':
            generator_losses['scaled']['reconstruction_loss'],
        'gen_scaled_change_loss':
            generator_losses['scaled']['change_loss'],
        'gen_scaled_overall_patch_loss':
            generator_losses['scaled']['overall_patch_loss'],
        'gen_unscaled_reconstruction_loss':
            generator_losses['unscaled']['reconstruction_loss'],
        'gen_unscaled_change_loss':
            generator_losses['unscaled']['change_loss'],
        'gen_unscaled_overall_patch_loss':
            generator_losses['unscaled']['overall_patch_loss'],
        'discrim_overall_loss': discriminator_losses['overall_loss'],
        'discrim_fake_dither_loss': discriminator_losses['fake_dither_loss'],
        'discrim_true_dither_loss': discriminator_losses['true_dither_loss'],
        'generator_grad_norm_min': generator_grads_min_max[0],
        'generator_grad_norm_max': generator_grads_min_max[1],
        'discriminator_grad_norm_min': discriminator_grads_min_max[0],
        'discriminator_grad_norm_max': discriminator_grads_min_max[1],
        'num_sample_white_pixels': num_sample_white_pixels,
        'num_sample_black_pixels': num_sample_black_pixels
    }, step=epoch)

    # some stdout logging
    print("epoch", epoch,
          "generator_losses", generator_losses,
          "generator_grads_min_max", generator_grads_min_max,
          "discriminator_losses", discriminator_losses,
          "discriminator_grads_min_max", discriminator_grads_min_max,
          # "range_of_dithers", range_of_dithers,
          "num_sample_white_pixels", num_sample_white_pixels,
          "num_sample_black_pixels", num_sample_black_pixels)
    sys.stdout.flush()

    if finish_time is not None and time.time() > finish_time:
        print("time up", epoch)
        break

    if epoch >= 1 and (num_sample_white_pixels < 100000 or
                       num_sample_black_pixels < 100000):
        print("model collapse?", epoch)
        break
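
# example invocation (illustrative values only; the manifest path below is
# hypothetical, not from this repo):
#
#   python train.py \
#     --manifest-file frames/manifest.tsv \
#     --batch-size 8 \
#     --steps-per-epoch 100 \
#     --patch-size 64 \
#     --generator-sigmoid-b 4.0 \
#     --max-run-time 3600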