This repository has been archived by the owner on Nov 12, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy path: train.py
executable file
·102 lines (70 loc) · 3.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
import sys
import pickle
import matplotlib.pyplot as plt
import numpy as np
from keras.callbacks import TensorBoard
from keras.optimizers import RMSprop
from vaegan.models import create_models, build_graph
from vaegan.training import fit_models
from vaegan.data import celeba_loader, encoder_loader, decoder_loader, discriminator_loader, NUM_SAMPLES, mnist_loader
from vaegan.callbacks import DecoderSnapshot, ModelsCheckpoint
def set_trainable(model, trainable):
    """Set the trainable flag on a model and on every one of its layers.

    Keras consults both the model-level and the layer-level flags, so the
    whole hierarchy is updated in one pass to keep them consistent.
    """
    for target in [model, *model.layers]:
        target.trainable = trainable
def main():
    """Train the VAE-GAN and show a sample reconstruction.

    Builds the encoder/decoder/discriminator graphs, compiles one training
    model per sub-network (each with a different subset of weights frozen),
    runs the alternating training loop, pickles the loss histories, and
    finally plots one image next to its VAE reconstruction.

    Usage: ``train.py [initial_epoch]`` — when an epoch number is given,
    the weights checkpointed for that epoch are loaded and training resumes
    from there.
    """
    encoder, decoder, discriminator = create_models()
    encoder_train, decoder_train, discriminator_train, vae, vaegan = \
        build_graph(encoder, decoder, discriminator)

    # Optional resume: the first CLI argument is the epoch to restart from.
    try:
        initial_epoch = int(sys.argv[1])
    except (IndexError, ValueError):
        initial_epoch = 0

    epoch_format = '.{epoch:03d}.h5'

    if initial_epoch != 0:
        suffix = epoch_format.format(epoch=initial_epoch)
        encoder.load_weights('encoder' + suffix)
        decoder.load_weights('decoder' + suffix)
        discriminator.load_weights('discriminator' + suffix)

    batch_size = 64

    rmsprop = RMSprop(lr=0.0003)

    # Keras captures each layer's `trainable` flag at compile time, so the
    # flags are toggled between the three compile calls below to freeze a
    # different part of the network in each training model. The order of
    # these statements is therefore significant.
    set_trainable(encoder, False)
    set_trainable(decoder, False)
    discriminator_train.compile(rmsprop, ['binary_crossentropy'] * 3, ['acc'] * 3)
    discriminator_train.summary()

    set_trainable(discriminator, False)
    set_trainable(decoder, True)
    decoder_train.compile(rmsprop, ['binary_crossentropy'] * 2, ['acc'] * 2)
    decoder_train.summary()

    set_trainable(decoder, False)
    set_trainable(encoder, True)
    encoder_train.compile(rmsprop)
    encoder_train.summary()

    set_trainable(vaegan, True)

    checkpoint = ModelsCheckpoint(epoch_format, encoder, decoder, discriminator)
    decoder_sampler = DecoderSnapshot()
    callbacks = [checkpoint, decoder_sampler, TensorBoard()]

    epochs = 250
    steps_per_epoch = NUM_SAMPLES // batch_size

    # Explicit dtype: NumPy's default integer is 32-bit on some platforms
    # (e.g. Windows), where a high bound of 2**32 - 1 would raise
    # "ValueError: high is out of bounds for int32". uint32 holds the full
    # range; int() keeps the seed a plain Python integer for the loaders.
    seed = int(np.random.randint(2**32 - 1, dtype=np.uint32))

    # The loaders share one seed so they draw from the same image stream in
    # lock-step (presumably each seeds its own RNG — confirm in vaegan.data).
    img_loader = celeba_loader(batch_size, num_child=3, seed=seed)
    dis_loader = discriminator_loader(img_loader, seed=seed)
    dec_loader = decoder_loader(img_loader, seed=seed)
    enc_loader = encoder_loader(img_loader)

    models = [discriminator_train, decoder_train, encoder_train]
    generators = [dis_loader, dec_loader, enc_loader]

    # Short metric names mapped to integer indices; NOTE(review): these look
    # like positions in each model's output/metric list — verify against
    # fit_models before relying on them.
    metrics = [{'di_l': 1, 'di_l_t': 2, 'di_l_p': 3, 'di_a': 4, 'di_a_t': 7, 'di_a_p': 10},
               {'de_l_t': 1, 'de_l_p': 2, 'de_a_t': 3, 'de_a_p': 5},
               {'en_l': 0}]

    histories = fit_models(vaegan, models, generators, metrics, batch_size,
                           steps_per_epoch=steps_per_epoch, callbacks=callbacks,
                           epochs=epochs, initial_epoch=initial_epoch)

    with open('histories.pickle', 'wb') as f:
        pickle.dump(histories, f)

    # Show one original image above its reconstruction; pixel values are in
    # [-1, 1] and are rescaled to [0, 1] for imshow.
    x = next(celeba_loader(1))
    x_tilde = vae.predict(x)
    plt.subplot(211)
    plt.imshow((x[0].squeeze() + 1.) / 2.)
    plt.subplot(212)
    plt.imshow((x_tilde[0].squeeze() + 1.) / 2.)
    plt.show()


if __name__ == '__main__':
    main()