-
Notifications
You must be signed in to change notification settings - Fork 0
/
vgg16_bidirectional_lstm_hi_dim_train.py
35 lines (24 loc) · 1.37 KB
/
vgg16_bidirectional_lstm_hi_dim_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
from keras import backend as K
import sys
import os
def main():
    """Train a VGG16 + bidirectional-LSTM video classifier on UCF-101 (hi-dim variant).

    Steps performed:
      1. Configure Keras for channels-last ("tf") image ordering.
      2. Make the in-repo ``keras_video_classifier`` package importable.
      3. Call ``load_ucf`` on ``very_large_data`` next to this script. Per the
         original comment, this downloads UCF-101 if it is not already there.
      4. Fit ``VGG16BidirectionalLSTMVideoClassifier`` with
         ``vgg16_include_top=False``, passing ``models/UCF-101`` as the model
         directory.
      5. Save the training-history plot to
         ``reports/UCF-101/<model_name>-hi-dim-history.png``.
    """
    # 'tf' dim ordering is the same as the 'channels_last' data format.
    # Newer Keras releases dropped set_image_dim_ordering, so prefer the
    # newer API and only fall back to the legacy call on old installs.
    if hasattr(K, 'set_image_data_format'):
        K.set_image_data_format('channels_last')
    else:
        K.set_image_dim_ordering('tf')

    script_dir = os.path.dirname(__file__)

    # Make the package one level above this script importable when the
    # script is run directly (the library is not assumed to be installed).
    sys.path.append(os.path.join(script_dir, '..'))
    from keras_video_classifier.library.recurrent_networks import VGG16BidirectionalLSTMVideoClassifier
    from keras_video_classifier.library.utility.plot_utils import plot_and_save_history
    from keras_video_classifier.library.utility.ucf.UCF101_loader import load_ucf

    data_set_name = 'UCF-101'
    input_dir_path = os.path.join(script_dir, 'very_large_data')
    output_dir_path = os.path.join(script_dir, 'models', data_set_name)
    report_dir_path = os.path.join(script_dir, 'reports', data_set_name)

    # Seeds NumPy's global RNG only; other RNGs (e.g. the Keras backend's)
    # are not seeded here.
    np.random.seed(42)

    # this line downloads the video files of UCF-101 dataset if they are not available in the very_large_data folder
    load_ucf(input_dir_path)

    classifier = VGG16BidirectionalLSTMVideoClassifier()
    history = classifier.fit(data_dir_path=input_dir_path, model_dir_path=output_dir_path,
                             vgg16_include_top=False, data_set_name=data_set_name)

    model_name = VGG16BidirectionalLSTMVideoClassifier.model_name
    # Create the report directory before writing the plot into it. The
    # isdir + makedirs pair avoids relying on Python 3's exist_ok argument.
    if not os.path.isdir(report_dir_path):
        os.makedirs(report_dir_path)
    plot_and_save_history(history, model_name,
                          os.path.join(report_dir_path, model_name + '-hi-dim-history.png'))
if __name__ == "__main__":
    # Start training only when this file is executed as a script, never on import.
    main()