forked from JJN123/Fall-Detection
-
Notifications
You must be signed in to change notification settings - Fork 1
/
dae_main_train.py
47 lines (29 loc) · 1.05 KB
/
dae_main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from models import *
#from ImageExp import ImgExp
from ae_exp import AEExp
import numpy as np
def init_dae_exp(pre_load=None, regularizer_list=None, *, batch_size=16,
                 epochs=1, img_width=64, img_height=64, hor_flip=False,
                 initial_epoch=0, dset='UR-Filled'):
    """Build the ``DAE`` model and wrap it in an ``AEExp`` experiment.

    ``DAE`` and ``AEExp`` are project helpers; every argument below is
    forwarded to one or both of them unchanged. The keyword-only training
    settings default to the values this script previously hard-coded, so
    existing calls behave exactly as before.

    Args:
        pre_load: Forwarded to ``AEExp(pre_load=...)``; presumably a saved
            model to resume from -- TODO confirm against ``ae_exp``.
        regularizer_list: Regularizer names forwarded to ``DAE``, e.g.
            ``['Dropout']`` (the script's usage notes also mention
            ``'L1L2'``). ``None`` means no regularizers; a fresh empty list
            is created on each call.
        batch_size: Forwarded to ``AEExp(batch_size=...)``.
        epochs: Forwarded to ``AEExp(epochs=...)``.
        img_width: Image width forwarded to both ``DAE`` and ``AEExp``.
        img_height: Image height forwarded to both ``DAE`` and ``AEExp``.
        hor_flip: Forwarded to ``AEExp(hor_flip=...)``; presumably toggles
            horizontal-flip augmentation -- TODO confirm.
        initial_epoch: Forwarded to ``AEExp(initial_epoch=...)``.
        dset: Name of the data set to train on, forwarded to ``AEExp``.

    Returns:
        The configured ``AEExp`` instance. Training is not started here.
    """
    # A literal ``[]`` default is evaluated once and shared by every call,
    # so any mutation of it (e.g. by ``DAE``) would leak into later runs.
    if regularizer_list is None:
        regularizer_list = []

    autoencoder, model_name, model_type = DAE(
        img_width=img_width,
        img_height=img_height,
        regularizer_list=regularizer_list,
    )
    return AEExp(
        model=autoencoder,
        img_width=img_width,
        img_height=img_height,
        model_name=model_name,
        model_type=model_type,
        pre_load=pre_load,
        initial_epoch=initial_epoch,
        epochs=epochs,
        batch_size=batch_size,
        dset=dset,
        hor_flip=hor_flip,
    )
if __name__ == "__main__":
    def _run_experiment(regs):
        """Build, load data for, and train one experiment using *regs*."""
        experiment = init_dae_exp(regularizer_list=regs)
        experiment.set_train_data(raw=False)
        # Report the shape of the loaded training data before training.
        print(experiment.train_data.shape)
        experiment.train()

    # One training run per entry; each entry is a list of regularizer names.
    # 'L1L2' can be used as well as 'Dropout'.
    for regs in [['Dropout']]:
        _run_experiment(regs)