forked from johnryh/Face_Embedding_GAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
116 lines (79 loc) · 4.96 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
'''
Copyright 2018 - 2019 Duke University
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License Version 2 as published by
the Free Software Foundation.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License Version 2
along with this program. If not, see <https://www.gnu.org/licenses/old-licenses/gpl-2.0.txt>.
'''
from Input_Pipeline_celeba import *
from network_utility import *
from utilities import *
from tqdm import tqdm
import numpy as np
import os, time
# Quiet TensorFlow's native (C++) logging; level '3' is the most restrictive setting.
# NOTE(review): this runs after the star imports above, which are what bring `tf`
# into scope; TensorFlow may already have logged (and cached its threshold) by the
# time this is set, so consider exporting it before those imports — confirm.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Iteration index the training loop below resumes from; reported once at startup.
prev_phase_iter = get_prev_phase_iter()
print('prev_phase_iter:', prev_phase_iter)
# Interval (in training iterations) between sample-image dumps and checkpoint saves.
SAVE_INTERVAL = 1000
# Per-iteration increment applied to the g_alpha / d_alpha factors fed to the model.
ALPHA_STEP = 5e-5


def _ensure_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist."""
    if not os.path.exists(path):
        os.makedirs(path)


def _restore_weights(sess, model, resume_iter):
    """Load initial weights for the current phase into ``sess``, if applicable.

    * ``use_smooth`` and ``phase > 2``: restore the previous phase's latest
      checkpoint through both ``model.loader`` and ``model.d_smooth_loader``.
    * otherwise, ``phase >= 1``, not ``use_smooth`` and ``resume_iter > 0``:
      restore this phase's latest checkpoint through ``model.saver``.
    * otherwise nothing is loaded.

    NOTE(review): with ``use_smooth`` and ``phase > 2`` the previous phase's weights
    are loaded even when ``resume_iter > 0``, and with ``use_smooth`` at phase 2
    nothing is loaded — confirm both are intended.
    """
    if use_smooth and phase > 2:
        ckpt = 'runs/{}/model/phase_{}/iteration_latest/model_latest.ckpt'.format(exp_name, phase - 1)
        model.loader.restore(sess, ckpt)
        model.d_smooth_loader.restore(sess, ckpt)
        print('***Phase{}: Phase_{} weights loaded***'.format(phase, phase - 1))
    elif phase >= 1 and not use_smooth and resume_iter > 0:
        ckpt = 'runs/{}/model/phase_{}/iteration_latest/model_latest.ckpt'.format(exp_name, phase)
        model.saver.restore(sess, ckpt)
        print('***Phase{}: Phase_{} weights loaded***'.format(phase, phase))


def _save_checkpoints(sess, model, step):
    """Save a checkpoint for ``step`` and overwrite this phase's ``iteration_latest`` one.

    tf.train.Saver.save does not create missing parent directories, so both target
    directories are created first.
    """
    step_root = 'runs/{}/model/phase_{}/iteration_{}/'.format(exp_name, phase, step)
    _ensure_dir(step_root)
    model.saver.save(sess, step_root + 'model_{}.ckpt'.format(step))

    latest_root = 'runs/{}/model/phase_{}/iteration_latest/'.format(exp_name, phase)
    _ensure_dir(latest_root)
    model.saver.save(sess, latest_root + 'model_latest.ckpt')


if __name__ == '__main__':
    tf.reset_default_graph()
    real_img, mask = build_input_pipline(batch_size, train_tfrecord_path)
    z = get_z(batch_size, latent_size)

    # build model here
    real_img = tf.reshape(real_img, [batch_size, output_img_h, output_img_w, 3])
    mask = tf.reshape(mask, [batch_size, output_img_h, output_img_w, 1])
    print(real_img, mask)
    model = prog_w_gan(z, real_img, mask, phase=phase, LAMBDA=10)
    merged = tf.summary.merge_all()

    # Values fed to model.g_alpha / model.d_alpha: they start at ~0 and grow by
    # ALPHA_STEP every iteration, clipped to [0, 1].
    # NOTE(review): they restart at ~0 when resuming with prev_phase_iter > 0 — confirm.
    g_alpha = 1e-12
    d_alpha = 1e-12

    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True

    total_iters = int(epoch_num * total_samples / batch_size)
    sample_dir = 'runs/{}/samples/phase_{}/'.format(exp_name, phase)

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter('runs/{}/logs/phase_{}'.format(exp_name, phase), sess.graph)
        try:
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            sess.run(init_op)
            print('Session Initiated')

            _restore_weights(sess, model, prev_phase_iter)

            # The sample writers below are given paths inside this directory.
            _ensure_dir(sample_dir)

            with tqdm(total=max(0, total_iters - prev_phase_iter), unit='it') as pbar:
                for step in range(prev_phase_iter, total_iters):
                    step_start = time.time()
                    smooth_factors = {model.g_alpha: g_alpha, model.d_alpha: d_alpha}

                    # Run apply_d_grad alone n_critic - 1 times, then once more together
                    # with apply_g_grad while also fetching losses and the summary.
                    for _ in range(n_critic - 1):
                        sess.run([model.apply_d_grad], feed_dict=smooth_factors)
                    _, _, g_loss, d_loss, summary = sess.run(
                        [model.apply_d_grad, model.apply_g_grad, model.g_loss, model.d_loss, merged],
                        feed_dict=smooth_factors)

                    g_alpha = np.clip(g_alpha + ALPHA_STEP, 0, 1)
                    d_alpha = np.clip(d_alpha + ALPHA_STEP, 0, 1)

                    iter_per_sec = 1 / (time.time() - step_start)
                    train_writer.add_summary(summary, step)
                    if step % 10 == 0:
                        pbar.set_postfix({'it_ins/s': '{:4.2f}, d_loss:{}, g_loss:{}'.format(iter_per_sec, d_loss, g_loss)})
                    pbar.update(1)

                    if step == 0:
                        # Only the real batch is saved here: fake_mask_0.tif is written
                        # by the periodic dump just below, which also runs at step 0.
                        real_batch = sess.run(model.real_images, feed_dict=smooth_factors)
                        save_png(real_batch[:16, :, :, :], [4, 4], sample_dir + 'real_sample.png')

                    if step % SAVE_INTERVAL == 0:
                        fake_batch, fake_masks = sess.run([model.fake_images, model.fake_masks], feed_dict=smooth_factors)
                        save_png(fake_batch[:16, :, :, :], [4, 4], sample_dir + 'fake_{}.png'.format(step))
                        save_tiff(fake_masks[:16, :, :, :], [4, 4], sample_dir + 'fake_mask_{}.tif'.format(step))

                    if step % SAVE_INTERVAL == 0 and step != 0:
                        _save_checkpoints(sess, model, step)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            train_writer.close()