utils.py

# Packages to import
import os
import pickle


# Check whether a pickle file exists; if it does, load and return its contents,
# otherwise raise a RuntimeError with the given message
def check_file(path, msg):
    if os.path.exists(path):
        with open(path, 'rb') as file:
            results = pickle.load(file)
        return results
    else:
        raise RuntimeError(msg)

# Save dictionary to pickle file
def save(res, path):
    with open(path, 'wb') as handle:
        pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
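
# Example usage, a minimal sketch of the save/load round-trip (the file name and
# dictionary contents are hypothetical, not taken from the rest of the code):
#   save({'c_index': 0.7}, 'results.pickle')
#   results = check_file('results.pickle', 'results.pickle not found')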

# Hyperparameter tuning: build all possible combinations of parameters to train the base model.
# Once the best combination has been found, stop calling this and use that combination directly
def parameter_combination():
    hidden_size = [50, 100, 250, 500, 1000]
    latent_dim = [5, 10, 25, 50, 100]
    param_comb = []
    for hidden in hidden_size:
        for latent in latent_dim:
            new_params = {'hidden_size': hidden, 'latent_dim': latent}
            param_comb.append(new_params)
    return param_comb
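
# Example (values follow from the two grids above): the 5 x 5 grid yields 25 combinations,
# in row-major order over hidden_size, e.g.
#   parameter_combination()[0] -> {'hidden_size': 50, 'latent_dim': 5}
#   parameter_combination()[1] -> {'hidden_size': 50, 'latent_dim': 10}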

# Function that creates output directories for each task
def create_output_dir(task, args):
    for dataset_name in args['datasets']:
        if task == 'data_preprocessing':
            os.makedirs(args['output_dir'] + dataset_name + '/', exist_ok=True)
        elif task == 'sota_sa':
            for model in args['sota_models']:
                os.makedirs(
                    args['sota_output_dir'] + dataset_name + '/' + model + '/' + str(args['n_folds']) + '_folds/',
                    exist_ok=True)
        elif task == 'savae_sa':
            for params in args['param_comb']:
                for seed in range(args['n_seeds']):
                    model_path = str(params['latent_dim']) + '_' + str(params['hidden_size']) + '/seed_' + str(seed)
                    os.makedirs(args['output_dir'] + dataset_name + '/' + str(args['n_folds']) + '_folds/' +
                                model_path + '/', exist_ok=True)
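
# Example (with the defaults set in run_args below, n_folds=5 and the single default
# parameter combination): the 'savae_sa' task creates one folder per dataset, combination
# and seed, such as <output_dir>/whas/5_folds/5_50/seed_0/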

# Function that sets the environment configuration for a given task
def run_args(task):
    args = {}

    # Data
    datasets = []
    dataset_name = 'all'
    if dataset_name == 'all':
        datasets = ['whas', 'support', 'gbsg', 'flchain', 'nwtco', 'metabric', 'pbc', 'std', 'pneumon', 'crlm']
    else:
        datasets.append(dataset_name)
    args['datasets'] = datasets
    print('[INFO] Datasets: ', datasets)

    # Absolute path to this file's directory
    abs_path = os.path.dirname(os.path.abspath(__file__))

    # Depending on the task, set the input/output directories
    if task == 'data_preprocessing':
        args['output_dir'] = abs_path + '/data_preprocessing/data/'
        args['input_dir'] = abs_path + '/data_preprocessing/raw_data/'
    else:
        args['input_dir'] = abs_path + '/data_preprocessing/data/'

    # Training and testing configurations for SAVAE and SOTA models
    args['train'] = True
    args['eval'] = True
    args['early_stop'] = True
    args['n_folds'] = 5
    args['batch_size'] = 64
    args['n_epochs'] = 3000
    args['lr'] = 1e-3
    args['time_distribution'] = ('weibull', 2)

    # SOTA models
    args['sota_output_dir'] = abs_path + '/survival_analysis/output_sota_' + str(args['n_folds']) + '_folds_' + str(
        args['batch_size']) + '_batch_size_' + args['time_distribution'][0] + '/'
    model_name = 'all'
    args['sota_models'] = ['coxph', 'deepsurv', 'deephit'] if model_name == 'all' else [model_name]

    # SAVAE hyperparameters
    args['n_threads'] = 24
    args['n_seeds'] = 10
    default_params = True
    args['param_comb'] = [{'hidden_size': 50, 'latent_dim': 5}] if default_params else parameter_combination()

    # SAVAE output folders
    args['output_dir'] = abs_path + '/survival_analysis/output_savae_' + str(args['n_folds']) + '_folds_' + str(
        args['batch_size']) + '_batch_size_' + args['time_distribution'][0] + '/'

    return args
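

# Minimal usage sketch when running this module as a script ('savae_sa' is one of the
# tasks handled above; the chosen task is an assumption for illustration):
if __name__ == '__main__':
    savae_args = run_args('savae_sa')
    create_output_dir('savae_sa', savae_args)
    print('[INFO] SAVAE output dir: ', savae_args['output_dir'])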