train.py
"""Train the model."""
import argparse
import glob
import os
import math
from time import time
import keras.backend as K
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.layers import Input
from keras.models import load_model
from keras.optimizers import Adam
from model.flowchroma_network import FlowChroma
from model.fusion_layer import FusionLayer
from dataset.utils.shared import frames_per_video, default_nn_input_width, default_nn_input_height, dir_lab_records, \
    dir_resnet_csv
from dataset.data_generator import DataGenerator
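# all training hyper-parameters are exposed as command-line flags;
# the record directories default to the paths defined in dataset.utils.shared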
parser = argparse.ArgumentParser(description='Train FlowChroma')
parser.add_argument('-r', '--resnet-records',
                    type=str,
                    metavar='FILE',
                    dest='resnet_path',
                    default=dir_resnet_csv,
                    help='directory of ResNet records')
parser.add_argument('-l', '--lab-records',
                    type=str,
                    metavar='FILE',
                    dest='lab_path',
                    default=dir_lab_records,
                    help='directory of Lab records')
parser.add_argument('-s', '--split-ratio',
                    type=float,
                    default=0.1,
                    dest='val_split_ratio',
                    help='validation split ratio')
parser.add_argument('-lr', '--learning-rate',
                    type=float,
                    default=0.001,
                    dest='lr',
                    help='learning rate for the optimizer')
parser.add_argument('-t', '--train-batch-size',
                    type=int,
                    dest='train_batch_size',
                    default=4,
                    help='batch size of training set')
parser.add_argument('-v', '--val-batch-size',
                    type=int,
                    dest='val_batch_size',
                    default=4,
                    help='batch size of validation set')
parser.add_argument('-e', '--epochs',
                    type=int,
                    dest='n_epochs_to_train',
                    default=10,
                    help='number of epochs to train')
parser.add_argument('-c', '--ckpt-period',
                    type=int,
                    dest='ckpt_period',
                    default=2,
                    help='checkpoint period (in epochs)')
args = parser.parse_args()
resnet_path = args.resnet_path
lab_path = args.lab_path
val_split_ratio = args.val_split_ratio
lr = args.lr
train_batch_size = args.train_batch_size
val_batch_size = args.val_batch_size
n_epochs_to_train = args.n_epochs_to_train
ckpt_period = args.ckpt_period
time_steps, h, w = frames_per_video, default_nn_input_height, default_nn_input_width
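# each training sample is a clip of frames_per_video frames at the default
# network input height and width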
initial_epoch = 0
ckpts = glob.glob("checkpoints/*.hdf5")
if len(ckpts) != 0:
    # resume from the most recent checkpoint; the epoch index is parsed from its file name
    latest_ckpt = max(ckpts, key=os.path.getctime)
    print("loading from checkpoint:", latest_ckpt)
    initial_epoch = int(latest_ckpt[latest_ckpt.find("-epoch-") + len("-epoch-"):latest_ckpt.rfind("-lr-")])
    model = load_model(latest_ckpt, custom_objects={'FusionLayer': FusionLayer})
else:
    # no checkpoints found: build and compile a fresh FlowChroma model
    # single-channel frame input and per-frame 1000-d feature vector input
    enc_input = Input(shape=(time_steps, h, w, 1), name='encoder_input')
    incep_out = Input(shape=(time_steps, 1000), name='inception_input')
    model = FlowChroma([enc_input, incep_out]).build()
    opt = Adam(lr=lr)
    model.compile(optimizer=opt, loss='mse', metrics=['accuracy'])
# generate_model_summaries(model)
# count the record files in the directories given on the command line;
# the two directories must contain the same number of records
n_lab_records = len(glob.glob('{0}/*.npy'.format(lab_path)))
n_resnet_records = len(glob.glob('{0}/*.npy'.format(resnet_path)))
assert n_lab_records == n_resnet_records
val_split = int(math.floor(n_lab_records * val_split_ratio))
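# the first val_split clips (ids zero-padded to five digits to match the
# record file names) form the validation set; the rest are used for training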
dataset = {
    "validation": ['{0:05}'.format(i) for i in range(val_split)],
    "train": ['{0:05}'.format(i) for i in range(val_split, n_lab_records)]
}
basic_generator_params = {
    "resnet_path": resnet_path,
    "lab_path": lab_path,
    "time_steps": time_steps,
    "h": h,
    "w": w
}
# training and validation batch generators
training_generator = DataGenerator(**basic_generator_params,
                                   file_ids=dataset['train'],
                                   batch_size=train_batch_size)
validation_generator = DataGenerator(**basic_generator_params,
                                     file_ids=dataset['validation'],
                                     batch_size=val_batch_size)
os.makedirs("checkpoints", exist_ok=True)
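# checkpoint file names encode the epoch, learning rate, training loss and
# validation loss, so the resume logic above can recover the epoch number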
file_path = "checkpoints/flowchroma-epoch-{epoch:05d}-lr-" + str(
    lr) + "-train_loss-{loss:.4f}-val_loss-{val_loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(file_path,
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=False,
                             save_weights_only=False,
                             mode='min',
                             period=ckpt_period)
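# write TensorBoard logs to a run-specific directory named after the current Unix timestamp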
tensorboard = TensorBoard(log_dir="logs/{}".format(time()), histogram_freq=0)
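# when resuming, if the requested epoch count has already been reached,
# interpret it as additional epochs on top of the resumed epoch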
if n_epochs_to_train <= initial_epoch:
    n_epochs_to_train += initial_epoch
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    epochs=n_epochs_to_train,
                    initial_epoch=initial_epoch,
                    callbacks=[checkpoint, tensorboard],
                    workers=6)
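# release the backend session once training finishes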
K.clear_session()