# train_tm_model.py — fit auton-survival 'dsm' models over a hyperparameter grid.
import os
import pandas as pd
import numpy as np
import pickle
import sys
sys.path.append(os.path.expanduser("~/TMPredictor/survival_tm/auton-survival"))
from sklearn.model_selection import ParameterGrid
from auton_survival.estimators import SurvivalModel
# Pickled inputs for each feature-set variant, in the order
# (x_train, x_valid, outcomes_train, outcomes_valid).
# Earlier experiments switched variants by wrapping the unused loaders in
# triple-quoted strings; pick one via DATASET instead.
DATA_FILES = {
    'memory': ('data/x_train_memory_normed.pkl',
               'data/x_valid_memory_normed.pkl',
               'data/outcomes_train_memory.pkl',
               'data/outcomes_valid_memory.pkl'),
    'df': ('data/x_train_df_normed.pkl',
           'data/x_valid_df_normed.pkl',
           'data/outcomes_train_df.pkl',
           'data/outcomes_valid_df.pkl'),
    'future': ('data/x_train_future_normed.pkl',
               'data/x_valid_future_normed.pkl',
               'data/outcomes_train_future.pkl',
               'data/outcomes_valid_future.pkl'),
    # The PCA features have no outcome files of their own; they reuse the
    # outcome tables of the 'df' variant.
    'pca': ('data/x_train_pca.pkl',
            'data/x_valid_pca.pkl',
            'data/outcomes_train_df.pkl',
            'data/outcomes_valid_df.pkl'),
}

# Feature-set variant to train on (a key of DATA_FILES).
DATASET = 'pca'


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


x_train_df, x_valid_df, outcomes_train_df, outcomes_valid_df = [
    _load_pickle(path) for path in DATA_FILES[DATASET]
]
# Candidate values for each SurvivalModel hyperparameter. ParameterGrid
# expands the Cartesian product of these lists; with one value per key the
# grid currently yields a single configuration.
param_grid = {'k': [3],
              'iters': [200],
              'distribution': ['LogNormal'],
              'learning_rate': [1e-5],
              'batch_size': [1000],
              'layers': [
                  [1000, 1000]
              ]
              }
params = ParameterGrid(param_grid)

# Create the output directory up front so the final dump cannot fail with
# FileNotFoundError after the whole grid has already been trained.
os.makedirs('models', exist_ok=True)

# One [model, train_loss, val_loss, param] entry per grid point.
models = []
for i, param in enumerate(params):
    # 1-based so the last configuration reads "N of N".
    print(f'Hyperparameter {i + 1} of {len(params)}')
    print(param)
    model = SurvivalModel(model='dsm',
                          iters=param['iters'],
                          k=param['k'],
                          layers=param['layers'],
                          distribution=param['distribution'],
                          learning_rate=param['learning_rate'],
                          batch_size=param['batch_size']
                          )
    # NOTE(review): x_valid_df / outcomes_valid_df are loaded but not passed
    # here. The alternative call was
    #   model.fit(x_train_df, outcomes_train_df,
    #             val_data=(x_valid_df, outcomes_valid_df))
    # The 3-tuple unpacking assumes the auton-survival copy on sys.path
    # returns (model, train_loss, val_loss) from fit(). Confirm this, and
    # confirm what val_loss measures when val_data is omitted.
    _, train_loss, val_loss = model.fit(x_train_df, outcomes_train_df)
    models.append([model, train_loss, val_loss, param])

with open('models/pca_models.pkl', 'wb') as f:
    pickle.dump(models, f)