-
Notifications
You must be signed in to change notification settings - Fork 30
/
train.py
executable file
·116 lines (95 loc) · 3.66 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import cPickle as pickle
import tensorflow as tf
from scipy import misc
import numpy as np
import argparse
import ntpath
import sys
import os
import time
import data_ops
import net
if __name__ == '__main__':
    # Training entry point: parse the command-line options, build the graph
    # from data_ops / net, resume from the latest checkpoint in
    # CHECKPOINT_DIR when one exists, then train until EPOCHS epochs' worth
    # of steps have been run, checkpointing every 500 steps.
    parser = argparse.ArgumentParser()
    parser.add_argument('--EPOCHS', required=False, default=10, type=int,
                        help='Number of epochs to train for')
    parser.add_argument('--DATA_DIR', required=True,
                        help='Directory where data is')
    parser.add_argument('--BATCH_SIZE', required=False, type=int, default=32,
                        help='Batch size to use')
    a = parser.parse_args()

    EPOCHS = a.EPOCHS
    DATA_DIR = a.DATA_DIR
    BATCH_SIZE = a.BATCH_SIZE

    CHECKPOINT_DIR = 'checkpoints/'
    IMAGES_DIR = CHECKPOINT_DIR+'images/'

    for directory in (CHECKPOINT_DIR, IMAGES_DIR):
        try:
            os.mkdir(directory)
        except OSError:
            # An already-existing directory is expected on resumed runs; any
            # other failure (permissions, bad path) should surface here
            # rather than as a confusing error further down.
            if not os.path.isdir(directory):
                raise

    # write all this info to a pickle file in the checkpoint directory
    exp_info = {
        'EPOCHS': EPOCHS,
        'DATA_DIR': DATA_DIR,
        'BATCH_SIZE': BATCH_SIZE,
    }
    with open(CHECKPOINT_DIR+'info.pkl', 'wb') as exp_pkl:
        pickle.dump(exp_info, exp_pkl)

    # Single-argument print(...) calls emit the same text under the Python 2
    # print statement and the Python 3 print function.
    print('')
    print('EPOCHS:  %s' % EPOCHS)
    print('DATA_DIR:  %s' % DATA_DIR)
    print('BATCH_SIZE:  %s' % BATCH_SIZE)
    print('')

    # global step that is saved with a model to keep track of how many
    # steps/epochs have been run
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # load data
    Data = data_ops.loadData(DATA_DIR, BATCH_SIZE)
    num_train = Data.count
    gray_image = Data.inputs
    color_image = Data.targets

    # network output for the input batch (model defined in net.architecture)
    col_img = net.architecture(gray_image)

    # NOTE(review): tf.nn.l2_loss already reduces to a scalar
    # (sum(t ** 2) / 2), so the outer reduce_mean is a no-op and the loss is
    # a sum, not a per-element mean. Kept as is so resumed runs and existing
    # loss curves stay comparable; confirm whether a mean was intended.
    loss = tf.reduce_mean(tf.nn.l2_loss(color_image-col_img))
    train_op = tf.train.AdamOptimizer(learning_rate=1e-6).minimize(
        loss, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=1)

    # tensorboard summaries
    tf.summary.scalar('loss', loss)

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init)

    # write out logs for tensorboard to the checkpoint dir
    summary_writer = tf.summary.FileWriter(CHECKPOINT_DIR+'logs/',
                                           graph=tf.get_default_graph())

    # restore previous model if there is one
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        print('Restoring previous model...')
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Model restored')
        except Exception as err:
            # Best effort: fall back to the freshly initialised variables,
            # but report why the restore failed instead of hiding it.
            print('Could not restore model')
            print(err)

    ########################################### training portion
    step = sess.run(global_step)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)

    merged_summary_op = tf.summary.merge_all()

    # Floor division keeps Python 2's integer semantics explicit; clamp to 1
    # so a dataset smaller than one batch cannot cause a division by zero.
    # Assumes Data.count is an integer number of training examples.
    steps_per_epoch = max(1, num_train // BATCH_SIZE)

    start = time.time()
    try:
        # Test the epoch of the *current* step so that no extra step is run
        # once the requested number of epochs has been reached.
        while step // steps_per_epoch < EPOCHS:
            epoch_num = step // steps_per_epoch
            s = time.time()

            # Fetch everything in one run call: separate calls would each
            # pull a fresh batch if the inputs are queue-fed (queue runners
            # are started above), so the reported loss would not belong to
            # the batch that was just trained on.
            _, loss_, summary = sess.run([train_op, loss, merged_summary_op])
            summary_writer.add_summary(summary, step)

            print('epoch: %s step: %s loss: %s time: %s'
                  % (epoch_num, step, loss_, time.time()-s))
            step += 1

            if step % 500 == 0:
                print('Saving model...')
                # Saver.save also writes '<prefix>.meta' by default, so no
                # separate export_meta_graph call is needed.
                saver.save(sess, CHECKPOINT_DIR+'checkpoint-'+str(step))
                print('Model saved\n')

        print('Finished training %s' % (time.time()-start))
        saver.save(sess, CHECKPOINT_DIR+'checkpoint-'+str(step))
    finally:
        # Shut the input threads down and flush the event file even when
        # training is interrupted or fails.
        coord.request_stop()
        coord.join(threads)
        summary_writer.close()
        sess.close()

    sys.exit()