-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
117 lines (103 loc) · 3.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import csv
import time
import numpy as np
from datasets.dataset_survival import Generic_MIL_Survival_Dataset
from utils.options import parse_args
from utils.util import get_split_loader, set_seed
from utils.loss import define_loss
from utils.optimizer import define_optimizer
from utils.scheduler import define_scheduler
def main(args):
    """Run 5-fold cross-validated survival training and write a results CSV.

    Args:
        args: parsed command-line namespace. Fields read here: seed, dataset,
            model, fusion, alpha, modal, OOM, data_root_dir, which_splits,
            weighted_sample, batch_size, model_size.

    Raises:
        NotImplementedError: if ``args.model`` is not a supported model name
            (only ``"cmta"`` is implemented in this block).

    Side effects:
        Creates ``./results/<dataset>/[...]`` (time-stamped per run) and
        writes ``results.csv`` containing the best validation epoch and best
        c-index for each fold plus the mean/std over the five folds.
    """
    # Fix RNG seeds so every run is reproducible.
    set_seed(args.seed)

    # Time-stamped results directory, unique per (dataset, model, fusion,
    # alpha) configuration and run time.
    results_dir = "./results/{dataset}/[{model}]-[{fusion}]-[{alpha}]-[{time}]".format(
        dataset=args.dataset,
        model=args.model,
        fusion=args.fusion,
        alpha=args.alpha,
        time=time.strftime("%Y-%m-%d]-[%H-%M-%S"),
    )
    # exist_ok avoids the exists()/makedirs() race of the original LBYL check.
    os.makedirs(results_dir, exist_ok=True)

    # CSV rows: one cell per fold, then mean and std columns.
    header = ["folds", "fold 0", "fold 1", "fold 2", "fold 3", "fold 4", "mean", "std"]
    best_epoch = ["best epoch"]
    best_score = ["best cindex"]

    # Loop-invariant: the split directory depends only on args, not the fold.
    split_dir = os.path.join("./splits", args.which_splits, args.dataset)

    # 5-fold cross-validation.
    for fold in range(5):
        # NOTE(review): the dataset is rebuilt identically every fold
        # (shuffle=False, fixed seed). It could likely be hoisted out of the
        # loop, but only if return_splits() does not mutate the dataset —
        # confirm before changing.
        dataset = Generic_MIL_Survival_Dataset(
            csv_path="./csv/%s_all_clean.csv" % (args.dataset),
            modal=args.modal,
            OOM=args.OOM,
            apply_sig=True,
            data_dir=args.data_root_dir,
            shuffle=False,
            seed=args.seed,
            patient_strat=False,
            n_bins=4,
            label_col="survival_months",
        )
        train_dataset, val_dataset = dataset.return_splits(
            from_id=False, csv_path="{}/splits_{}.csv".format(split_dir, fold)
        )
        train_loader = get_split_loader(
            train_dataset,
            training=True,
            weighted=args.weighted_sample,
            modal=args.modal,
            batch_size=args.batch_size,
        )
        val_loader = get_split_loader(
            val_dataset, modal=args.modal, batch_size=args.batch_size
        )
        print(
            "training: {}, validation: {}".format(len(train_dataset), len(val_dataset))
        )

        # Build model, criterion, optimizer, scheduler for this fold.
        if args.model == "cmta":
            from models.cmta.network import CMTA
            from models.cmta.engine import Engine

            print(train_dataset.omic_sizes)
            model_dict = {
                "omic_sizes": train_dataset.omic_sizes,
                "n_classes": 4,
                "fusion": args.fusion,
                "model_size": args.model_size,
            }
            model = CMTA(**model_dict)
            criterion = define_loss(args)
            optimizer = define_optimizer(args, model)
            scheduler = define_scheduler(args, optimizer)
            engine = Engine(args, results_dir, fold)
        else:
            raise NotImplementedError(
                "Model [{}] is not implemented".format(args.model)
            )

        # Train/validate; Engine.learning returns (best c-index, best epoch).
        score, epoch = engine.learning(
            model, train_loader, val_loader, criterion, optimizer, scheduler
        )

        # Record per-fold bests for the summary CSV.
        best_epoch.append(epoch)
        best_score.append(score)

    # Summary columns: epochs get placeholders; scores get mean/std over
    # the five fold entries (indices 1..5 — index 0 is the row label).
    best_epoch.append("~")
    best_epoch.append("~")
    best_score.append(np.mean(best_score[1:6]))
    best_score.append(np.std(best_score[1:6]))

    # Write the three-row summary (header, epochs, scores).
    csv_path = os.path.join(results_dir, "results.csv")
    print("############", csv_path)
    with open(csv_path, "w", encoding="utf-8", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(header)
        writer.writerow(best_epoch)
        writer.writerow(best_score)
if __name__ == "__main__":
    args = parse_args()
    # main() returns nothing; don't bind its result to an unused variable.
    main(args)
    print("finished!")