train.py
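"""Train the audio-visual separation model (AV_model): build or resume it, then fit it
on the AVdataset train/val sample lists with checkpointing, learning-rate scheduling
and TensorBoard logging."""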
import sys
sys.path.append('./model/model/')
import AV_model as AV
from option import ModelMGPU, latest_file
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
from keras.models import load_model
from data_load import AVGenerator
from keras import optimizers
import os
from loss import audio_discriminate_loss2 as audio_loss
import tensorflow as tf
import matplotlib.pyplot as plt
# Resume Model
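# set True to resume training from the latest checkpoint saved in model_path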
resume_state = False
# Parameters
people_num = 2
epochs = 10
initial_epoch = 0
batch_size = 1
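# weighting factors passed to audio_discriminate_loss2 (imported above as audio_loss)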
gamma_loss = 0.1
beta_loss = gamma_loss * 2
# Accelerate Training Process
workers = 8
MultiProcess = True
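# number of GPUs: > 1 wraps the model with ModelMGPU, otherwise AV_model is trained directly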
NUM_GPU = 0
# PATH
model_path = './saved_AV_models' # model path
database_path = 'data/'
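# database_path is expected to contain AVdataset_train.txt and AVdataset_val.txt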
# create folder to save models
if not os.path.exists(model_path):
    os.makedirs(model_path)
    print('create folder to save models')
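# save a checkpoint whenever val_loss improves; the file name encodes speaker count, epoch and val_loss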
filepath = model_path + "/AVmodel-" + str(people_num) + "p-{epoch:03d}-{val_loss:.5f}.h5"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
# learning-rate schedule: 1e-5, divided by 5 from epoch 5 and by 10 from epoch 10
def scheduler(epoch):
    ini_lr = 0.00001
    lr = ini_lr
    if epoch >= 5:
        lr = ini_lr / 5
    if epoch >= 10:
        lr = ini_lr / 10
    return lr
rlr = LearningRateScheduler(scheduler, verbose=1)
# format: mix.npy single.npy single.npy
trainfile = []
valfile = []
with open(database_path + 'AVdataset_train.txt', 'r') as t:
    trainfile = t.readlines()
with open(database_path + 'AVdataset_val.txt', 'r') as v:
    valfile = v.readlines()
# training steps: build or resume the model, then compile and fit
if resume_state:
    latest_file_path = latest_file(model_path + '/')
    AV_model = load_model(latest_file_path, custom_objects={"tf": tf})
    info = latest_file_path.strip().split('-')
    initial_epoch = int(info[-2])
else:
    AV_model = AV.AV_model(people_num)
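# generators that assemble training/validation batches from the sample lists above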
train_generator = AVGenerator(trainfile, database_path=database_path, batch_size=batch_size, shuffle=True)
val_generator = AVGenerator(valfile, database_path=database_path, batch_size=batch_size, shuffle=True)
if NUM_GPU > 1:
    parallel_model = ModelMGPU(AV_model, NUM_GPU)
    adam = optimizers.Adam()
    loss = audio_loss(gamma=gamma_loss, beta=beta_loss, people_num=people_num)
    # compile with an accuracy metric so history contains 'acc'/'val_acc' for the plots below
    parallel_model.compile(loss=loss, optimizer=adam, metrics=['accuracy'])
    AV_model.summary()
    history = parallel_model.fit_generator(generator=train_generator,
                                           validation_data=val_generator,
                                           epochs=epochs,
                                           workers=workers,
                                           use_multiprocessing=MultiProcess,
                                           callbacks=[TensorBoard(log_dir='./log_AV'), checkpoint, rlr],
                                           initial_epoch=initial_epoch
                                           )
    # summarize history for accuracy
    plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
if NUM_GPU <= 1:
    adam = optimizers.Adam()
    loss = audio_loss(gamma=gamma_loss, beta=beta_loss, people_num=people_num)
    # compile with an accuracy metric so history contains 'acc'/'val_acc' for the plots below
    AV_model.compile(optimizer=adam, loss=loss, metrics=['accuracy'])
    AV_model.summary()
    history = AV_model.fit_generator(generator=train_generator,
                                     validation_data=val_generator,
                                     epochs=epochs,
                                     workers=workers,
                                     use_multiprocessing=MultiProcess,
                                     callbacks=[TensorBoard(log_dir='./log_AV'), checkpoint, rlr],
                                     initial_epoch=initial_epoch
                                     )
    # summarize history for accuracy
    plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()