inference.py

import os
import sys
import yaml
import json
import torch
import pickle
import shutil
import logging
import warnings
import argparse
from os import path
from datetime import datetime
from torchmetrics.classification import AUROC, Accuracy
from src.utility.builtin import ODTrainer, ODLightningCLI
from src.utility.notify import send_to_telegram
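

# Stand-alone inference entry point: builds a Lightning CLI from the model and
# data configs, restores the given checkpoint, runs (optionally distributed)
# prediction over every test dataloader, and writes per-dataset accuracy/AUROC
# reports next to the model config.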
def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("model_cfg_path", type=str)
    parser.add_argument("data_cfg_path", type=str)
    parser.add_argument("model_ckpt_path", type=str)
    parser.add_argument("--precision", type=str, default="16")
    parser.add_argument("--devices", type=int, default=-1)
    parser.add_argument("--notes", type=str, default='')
    return parser.parse_args(args=args)
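

# Accumulates predicted probabilities keyed by video name so that a per-video
# mean probability can be used for metric computation.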
class StatsRecorder:
    def __init__(self, label):
        self.label = label
        self.prob = 0
        self.count = 0

    def update(self, prob, label):
        assert label == self.label
        self.prob += prob
        self.count += 1

    def compute(self):
        return {
            "label": self.label,
            "prob": self.prob / self.count
        }


def configure_logging():
    logging_fmt = "[%(levelname)s][%(filename)s:%(lineno)d]: %(message)s"
    logging.basicConfig(level="INFO", format=logging_fmt)
    warnings.filterwarnings(action="ignore")
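

# Runs trainer.predict on every test dataloader, gathers the results from all
# ranks, averages probabilities per video on the global-zero rank, and computes
# accuracy and AUROC for each dataset.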
@torch.inference_mode()
def inference_driver(cli, cfg_dir, ckpt_path, notes=None):
    timestamp = datetime.now().strftime("%m%dT%H%M%S")
    trainer = cli.trainer

    # setup model
    model = cli.model
    try:
        model = model.__class__.load_from_checkpoint(ckpt_path)
    except Exception as e:
        print(f"Unable to load model from checkpoint in strict mode: {e}")
        print("Loading model from checkpoint in non-strict mode.")
        model = model.__class__.load_from_checkpoint(ckpt_path, strict=False)
    model.eval()

    # setup dataset
    datamodule = cli.datamodule
    datamodule.prepare_data()
    datamodule.affine_model(cli.model)
    datamodule.setup('test')

    stats = {}
    report = {}
    test_dataloaders = datamodule.test_dataloader()
    for dts_name, dataloader in test_dataloaders.items():
        # iterate all videos
        auc_calc = AUROC(task="binary", num_classes=2)
        acc_calc = Accuracy(task="binary", num_classes=2)
        dataset = dataloader.dataset
        dts_stats = {}

        # perform ddp prediction
        batch_results = trainer.predict(
            model=model,
            dataloaders=[dataloader]
        )
        gathered_results = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_results, batch_results)
        torch.distributed.barrier()
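
        # Only the global-zero rank aggregates the gathered predictions and
        # computes metrics; the remaining ranks simply fall through.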
        if trainer.is_global_zero:
            # fetch predict results and aggregate.
            for rank_results in gathered_results:
                for batch_result in rank_results:
                    probs = batch_result["probs"]
                    names = batch_result["names"]
                    y = batch_result["y"]
                    for prob, label, name in zip(probs, y, names):
                        if name not in dts_stats:
                            dts_stats[name] = StatsRecorder(label)
                        dts_stats[name].update(prob, label)

            # compute the average probability.
            for k in dts_stats:
                dts_stats[k] = dts_stats[k].compute()

            # add straying videos into metric calculation
            for k, v in dataset.stray_videos.items():
                dts_stats[k] = dict(
                    label=v,
                    prob=0.5,
                    stray=1
                )

            # compute the metric scores
            dataset_labels = []
            dataset_probs = []
            for v in dts_stats.values():
                dataset_labels.append(v["label"])
                dataset_probs.append(v["prob"])
            dataset_labels = torch.tensor(dataset_labels)
            dataset_probs = torch.tensor(dataset_probs)

            accuracy = acc_calc(dataset_probs, dataset_labels).item()
            roc_auc = auc_calc(dataset_probs, dataset_labels).item()
            accuracy = round(accuracy, 3)
            roc_auc = round(roc_auc, 3)
            logging.info(f'[{dts_name}] accuracy: {accuracy}, roc_auc: {roc_auc}')

            stats[dts_name] = dts_stats
            report[dts_name] = {
                "accuracy": accuracy,
                "roc_auc": roc_auc
            }

    if trainer.is_global_zero:
        # save report and stats.
        with open(path.join(cfg_dir, f'report_{timestamp}.json'), "w") as f:
            json.dump(report, f, sort_keys=True, indent=4, separators=(',', ': '))
        with open(path.join(cfg_dir, f'stats_{timestamp}.pickle'), "wb") as f:
            pickle.dump(stats, f)
        if notes is not None:
            send_to_telegram(f"Inference for '{cfg_dir.split('/')[-2]}' Complete! (notes: {notes})")
            send_to_telegram(json.dumps(report, sort_keys=True, indent=4, separators=(',', ': ')))

    return report
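

# Example invocation (paths are illustrative, not fixed by this script):
#   python inference.py logs/<run>/model.yaml logs/<run>/data.yaml \
#       logs/<run>/best.ckpt --precision 16 --devices 2 --notes "eval run"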
if __name__ == "__main__":
    configure_logging()
    params = parse_args()

    cli = ODLightningCLI(
        run=False,
        trainer_class=ODTrainer,
        save_config_callback=None,
        parser_kwargs={
            "parser_mode": "omegaconf"
        },
        auto_configure_optimizers=False,
        seed_everything_default=1019,
        args=[
            '-c', params.model_cfg_path,
            '-c', params.data_cfg_path,
            '--trainer.logger=null',
            f'--trainer.devices={params.devices}',
            f'--trainer.precision={params.precision}',
        ],
    )

    cfg_dir = os.path.split(params.model_cfg_path)[0]
    ckpt_path = params.model_ckpt_path
    notes = params.notes

    inference_driver(
        cli=cli,
        cfg_dir=cfg_dir,
        ckpt_path=ckpt_path,
        notes=notes
    )