-
Notifications
You must be signed in to change notification settings - Fork 40
/
make_final_split.py
77 lines (55 loc) · 2.11 KB
/
make_final_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import hashlib
import utils
import utils_lung
import pathfinder
# Seeded generator: the random draws made from it are repeatable across runs.
rng = np.random.RandomState(42)

# Existing train/validation/test split; the three entries are concatenated
# below, so they are presumably lists of patient ids (verify against the pkl).
tvt_ids = utils.load_pkl(pathfinder.VALIDATION_SPLIT_PATH)
train_pids = tvt_ids['training']
valid_pids = tvt_ids['validation']
test_pids = tvt_ids['test']
all_pids = train_pids + valid_pids + test_pids
# Built as a single string so the text is the same whether print is a
# statement or a function.
print('%s %s' % ('total number of pids', len(all_pids)))

# Merge the test labels into the main label mapping; on a duplicate key the
# test label overwrites the earlier one.
id2label = utils_lung.read_labels(pathfinder.LABELS_PATH)
id2label_test = utils_lung.read_test_labels(pathfinder.TEST_LABELS_PATH)
id2label.update(id2label_test)
n_patients = len(id2label)
# Partition every labelled patient id by its binary label.
pos_ids = []
neg_ids = []
for pid, label in id2label.items():
    if label == 1:
        pos_ids.append(pid)
        continue
    if label == 0:
        neg_ids.append(pid)
        continue
    # Any label other than 0 or 1 is treated as corrupt input.
    raise ValueError("weird shit is going down")

# Fraction of positives among all labelled patients.
pos_ratio = float(len(pos_ids)) / n_patients
print('%s %s' % ('pos id ratio', pos_ratio))
# Fraction of the labelled patients to hold out as the final test set.
split_ratio = 0.15


def _rounded_share(count):
    """Return ``split_ratio * count`` rounded with ``np.round``, as an int."""
    return int(np.round(split_ratio * count))


n_target_split = _rounded_share(n_patients)
print('%s %s' % ('given split ratio', split_ratio))
print('%s %s' % ('target split ratio', float(n_target_split) / n_patients))

# Stratified sampling: draw the same fraction from each class, without
# replacement.  The positive draw must stay first so that the sequence of
# numbers consumed from ``rng`` is unchanged.
n_pos_ftest = _rounded_share(len(pos_ids))
n_neg_ftest = _rounded_share(len(neg_ids))
final_pos_test = rng.choice(pos_ids, n_pos_ftest, replace=False)
final_neg_test = rng.choice(neg_ids, n_neg_ftest, replace=False)

# Positives first, then negatives, in one flat array.
final_test = np.concatenate((final_pos_test, final_neg_test))
print('%s %s' % ('pos id ratio final test set',
                 float(len(final_pos_test)) / len(final_test)))
# Every pid of the existing split that was not sampled into ``final_test``
# becomes part of the final training set; the per-class lists are kept only
# to report the class balance.
final_train = []
final_pos_train = []
final_neg_train = []

# ``final_test`` is a NumPy array, so ``pid in final_test`` would scan the
# whole array for every pid.  Build a set of the plain Python values once and
# use constant-time lookups instead.
_final_test_lookup = set(final_test.tolist())

for pid in all_pids:
    if pid in _final_test_lookup:
        continue
    final_train.append(pid)
    pid_label = id2label[pid]
    if pid_label == 1:
        final_pos_train.append(pid)
    elif pid_label == 0:
        final_neg_train.append(pid)
    else:
        # Any label other than 0 or 1 is treated as corrupt input.
        raise ValueError("weird shit is going down")

print('%s %s' % ('pos id ratio final train set',
                 1. * len(final_pos_train) / len(final_train)))
print('%s %s' % ('final test/(train+test):',
                 1. * len(final_test) / (len(final_train) + len(final_test))))
# Fingerprint of the sampled test ids (joined in order) so that two runs can
# be compared at a glance.
concat_str = ''.join(final_test)
pid_digest = hashlib.md5(concat_str).hexdigest()
print('%s %s' % ('md5 of concatenated pids:', pid_digest))

# Persist the split next to the other metadata.
output_name = pathfinder.METADATA_PATH + 'final_split.pkl'
output = {'train': final_train, 'test': final_test}
utils.save_pkl(output, output_name)
print('%s %s' % ('final split saved at ', output_name))