'''
Project: SelfSupSurg
-----
Copyright (c) University of Strasbourg, All Rights Reserved.
'''
import os
from shutil import copy2
import glob
import torch
import random
import numpy as np
import h5py

from utils.file_helpers import parse_config, create_argument_parser
from downstream_phase_tcn.trainers import LinearEvalTrainer, TeCNOTrainer
from downstream_phase_tcn.data_loader import create_data_loaders

import logging
logging.basicConfig(level=logging.INFO)

def collect_embeddings(path):
    """Merge per-batch feature dumps (*_inds.npy, *_targets.npy, *_features.npy)
    for each split into a single extracted_features_<split>.hdf5 file."""
    splits = ['train', 'test', 'val']
    for split in splits:
        feat_path = os.path.join(path, split)
        if not os.path.exists(feat_path):
            continue
        if os.path.exists(os.path.join(
                path, 'extracted_features_{}.hdf5'.format(split))):
            print('embeddings already collected:', feat_path)
            continue
        print('collecting embeddings:', feat_path)
        finds = sorted(glob.glob(os.path.join(feat_path, '*_inds.npy')))
        ftargets = sorted(glob.glob(os.path.join(feat_path, '*_targets.npy')))
        ffeatures = sorted(glob.glob(os.path.join(feat_path, '*_features.npy')))
        inds = np.concatenate([np.load(f) for f in finds])
        targets = np.concatenate([np.load(f) for f in ftargets])
        features = np.concatenate([np.load(f) for f in ffeatures])
        # restore temporal order: the per-batch dumps are not guaranteed to be sorted
        order = np.argsort(inds)
        inds = inds[order]
        targets = targets[order]
        features = features[order]
        # split the combined index into video id and frame id: the last 8 digits
        # encode the frame, the leading digits the video
        # (e.g. index 300001234 -> video 3, frame 1234)
        frame_func = np.vectorize(lambda x: int(str(x)[-8:]))
        vid_func = np.vectorize(lambda x: int(str(x)[:-8]))
        frame_id = frame_func(inds)
        video_id = vid_func(inds)
        total_size = len(inds)
        embed_size = features.shape[1]
        file_path = os.path.join(path, 'extracted_features_{}.hdf5'.format(split))
        f = h5py.File(file_path, 'w')
        # write data
        f.create_dataset("frame_id", (total_size,), dtype='i', data=frame_id)
        f.create_dataset("video_id", (total_size,), dtype='i', data=video_id)
        f.create_dataset("embeddings", (total_size, embed_size), dtype='f', data=features)
        f.create_dataset("targets", (total_size,), dtype='f', data=targets)
        f.close()
    return
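

# ---------------------------------------------------------------------------
# Illustrative sketch only (not called anywhere in this script): how the HDF5
# file written by collect_embeddings() could be read back. The dataset names
# match those created above; the helper name and its argument are hypothetical.
# ---------------------------------------------------------------------------
def _read_embeddings_example(file_path):
    with h5py.File(file_path, 'r') as f:
        frame_id = f['frame_id'][:]      # frame index within each video
        video_id = f['video_id'][:]      # integer video identifier
        embeddings = f['embeddings'][:]  # (num_frames, embed_size) float features
        targets = f['targets'][:]        # per-frame phase labels
    return frame_id, video_id, embeddings, targets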


def Main(args):
    test_or_val = args.test_set
    cfg_file = os.path.join('configs/config', args.hyper_params)
    cfg = parse_config(cfg_file).config
    print("cfg_file: ", cfg_file)
    ckpt_dir = cfg.CHECKPOINT.DIR
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # copy2(cfg_file, ckpt_dir)
    num_epochs = cfg.OPTIMIZER.num_epochs
    seed = cfg.SEED_VALUE
    # seed all RNGs for reproducibility
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    shuffle = True
    # frame-level features extracted from the self-supervised backbone
    # (trunk features if available, otherwise head features)
    feat_path = os.path.join(
        cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE, "extracted_features_Trunk"
    )
    if not os.path.exists(feat_path):
        feat_path = feat_path.replace('Trunk', 'Head')
    collect_embeddings(feat_path)
    # FCN -> linear evaluation on frame features; otherwise TeCNO (TCN) training
    trainer = LinearEvalTrainer if cfg.MODEL.name == 'FCN' else TeCNOTrainer
    trainer = trainer(cfg)
    train_loader, test_loader, val_loader = create_data_loaders(cfg, shuffle=shuffle)
    trainer.add_data_loaders(train_loader, test_loader, val_loader)
    trainer.train()
    if cfg.MODEL.name == 'FCN':
        trainer.test(data_splits=[test_or_val])
    else:
        trainer.test(data_splits=['test'])
    return


if __name__ == '__main__':
    parser = create_argument_parser()
    args = parser.parse_args()
    Main(args)
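
# Example invocation (a sketch; the exact command-line flags are defined in
# utils.file_helpers.create_argument_parser and may differ from the names shown
# here, which simply mirror the argparse dests `hyper_params` and `test_set`
# used above):
#
#   python main_ft_phase_tcn.py --hyper_params <phase_tcn_config>.yaml --test_set val
#
# The config path is resolved relative to configs/config/; for the FCN
# (linear-eval) model, `test_set` picks whether 'test' or 'val' is evaluated.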