forked from Scalsol/mega.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
160 lines (142 loc) · 5.28 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import time
import os
import torch
from tqdm import tqdm
from mega_core.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str
from .bbox_aug import im_detect_bbox_aug
def compute_on_dataset(model, data_loader, device, bbox_aug, method, timer=None):
    """Run ``model`` over every batch of ``data_loader`` and collect its outputs on CPU.

    Args:
        model: detector to evaluate; it is put into eval mode before the loop.
        data_loader: iterable yielding ``(images, targets, image_ids)`` batches
            (``targets`` is unused here).
        device: ``torch.device`` the inputs are moved to before the forward pass.
        bbox_aug: when truthy, outputs come from ``im_detect_bbox_aug`` instead
            of a plain ``model(images)`` call.
        method: video method name. With ``"base"``, ``images`` must expose
            ``.to(device)``. With ``"rdn"``, ``"mega"``, ``"fgfa"`` or ``"dff"``,
            ``images`` is indexed like a dict with a ``"cur"`` entry and optional
            ``"ref"``, ``"ref_l"``, ``"ref_m"``, ``"ref_g"`` lists.
        timer: optional timer whose ``tic()``/``toc()`` bracket the per-batch
            device transfer and forward pass.

    Returns:
        dict mapping each image id of the batch to its output, moved to CPU.

    Raises:
        ValueError: if ``bbox_aug`` is falsy and ``method`` is not supported.
    """
    model.eval()
    cpu = torch.device("cpu")
    video_methods = ("rdn", "mega", "fgfa", "dff")
    reference_keys = ("ref", "ref_l", "ref_m", "ref_g")
    collected = {}

    for images, _targets, batch_ids in tqdm(data_loader):
        with torch.no_grad():
            if timer:
                timer.tic()

            if bbox_aug:
                outputs = im_detect_bbox_aug(model, images, device)
            elif method in ("base", ):
                images = images.to(device)
                outputs = model(images)
            elif method in video_methods:
                # Move the current frame and any present reference-frame lists.
                images["cur"] = images["cur"].to(device)
                for key in reference_keys:
                    if key in images.keys():
                        images[key] = [frame.to(device) for frame in images[key]]
                outputs = model(images)
            else:
                raise ValueError("method {} not supported yet.".format(method))

            if timer:
                # CUDA work is asynchronous; wait for it so toc() measures
                # the real forward time.
                if device.type != 'cpu':
                    torch.cuda.synchronize()
                timer.toc()

            outputs = [out.to(cpu) for out in outputs]

        collected.update(zip(batch_ids, outputs))

    return collected
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
    """Gather per-process prediction dicts and merge them into one ordered list.

    ``all_gather`` is invoked on every process before the main-process check
    (presumably a collective that all ranks must enter — TODO confirm in
    ``utils.comm``); only the main process then builds the merged result.

    Args:
        predictions_per_gpu: dict mapping image id -> prediction computed by
            this process.

    Returns:
        On the main process, a list of predictions ordered by ascending image
        id (an empty list if nothing was predicted on any process). ``None``
        on every other process.
    """
    all_predictions = all_gather(predictions_per_gpu)
    if not is_main_process():
        return None

    # Merge the list of per-process dicts; a later entry wins on a duplicate id.
    predictions = {}
    for process_predictions in all_predictions:
        predictions.update(process_predictions)

    # Without this guard, image_ids[-1] below would raise IndexError when no
    # predictions were gathered at all.
    if not predictions:
        return []

    image_ids = sorted(predictions)
    # Ids are expected to form 0..N-1; a gap means some images are missing.
    if len(image_ids) != image_ids[-1] + 1:
        logger = logging.getLogger("mega_core.inference")
        logger.warning(
            "Number of images that were gathered from multiple processes is not "
            "a contiguous set. Some images might be missing from the evaluation"
        )

    # Convert the id-keyed dict into a list ordered by image id.
    return [predictions[image_id] for image_id in image_ids]
def inference(
    cfg,
    model,
    data_loader,
    dataset_name,
    iou_types=("bbox",),
    motion_specific=False,
    box_only=False,
    bbox_aug=False,
    device="cuda",
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
):
    """Run the model over ``data_loader``, log timings, and evaluate the results.

    Args:
        cfg: config object; ``cfg.MODEL.VID.METHOD`` selects how batches are
            moved to the device.
        model: detector to evaluate.
        data_loader: loader whose ``.dataset`` is both iterated and evaluated.
        dataset_name: name used only in the start-of-run log message.
        iou_types, motion_specific, box_only, expected_results,
        expected_results_sigma_tol: forwarded unchanged to ``evaluate``.
        bbox_aug: forwarded to ``compute_on_dataset``.
        device: anything accepted by ``torch.device``.
        output_folder: if truthy, predictions are saved there as
            ``predictions.pth``; it is also forwarded to ``evaluate``.

    Returns:
        Whatever ``evaluate`` returns on the main process; ``None`` on every
        other process.
    """
    device = torch.device(device)
    world_size = get_world_size()
    logger = logging.getLogger("mega_core.inference")
    dataset = data_loader.dataset
    num_images = len(dataset)
    logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, num_images))

    wall_timer = Timer()
    forward_timer = Timer()
    wall_timer.tic()
    predictions = compute_on_dataset(
        model, data_loader, device, bbox_aug, cfg.MODEL.VID.METHOD, forward_timer
    )
    # Wait for every process to finish before reading the wall-clock time.
    synchronize()
    wall_time = wall_timer.toc()
    wall_time_str = get_time_str(wall_time)
    logger.info(
        "Total run time: {} ({} s / img per device, on {} devices)".format(
            wall_time_str, wall_time * world_size / num_images, world_size
        )
    )

    forward_time = forward_timer.total_time
    forward_time_str = get_time_str(forward_time)
    logger.info(
        "Model inference time: {} ({} s / img per device, on {} devices)".format(
            forward_time_str, forward_time * world_size / num_images, world_size
        )
    )

    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
    if not is_main_process():
        return None

    if output_folder:
        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))

    return evaluate(
        dataset=dataset,
        predictions=predictions,
        output_folder=output_folder,
        box_only=box_only,
        iou_types=iou_types,
        motion_specific=motion_specific,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )
def inference_no_model(
    data_loader,
    iou_types=("bbox",),
    motion_specific=False,
    box_only=False,
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
):
    """Evaluate predictions previously saved to disk, without running a model.

    Loads ``<output_folder>/predictions.pth`` (the file ``inference`` writes
    when given an ``output_folder``) and passes it to ``evaluate`` together
    with ``data_loader.dataset``.

    Args:
        data_loader: loader whose ``.dataset`` is evaluated.
        iou_types, motion_specific, box_only, expected_results,
        expected_results_sigma_tol: forwarded unchanged to ``evaluate``.
        output_folder: directory containing ``predictions.pth``; required.

    Returns:
        Whatever ``evaluate`` returns.

    Raises:
        ValueError: if ``output_folder`` is ``None`` (previously this failed
            with an opaque ``TypeError`` from ``os.path.join``).
    """
    if output_folder is None:
        raise ValueError(
            "inference_no_model requires output_folder: the directory "
            "containing predictions.pth"
        )
    logger = logging.getLogger("mega_core.inference")
    dataset = data_loader.dataset
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # predictions. On PyTorch >= 2.6 torch.load defaults to weights_only=True,
    # which may reject pickled prediction objects; passing weights_only=False
    # may then be needed (that keyword does not exist on very old PyTorch).
    predictions = torch.load(os.path.join(output_folder, "predictions.pth"))
    # Log through the same logger as `inference` instead of printing to stdout.
    logger.info("prediction loaded.")
    return evaluate(
        dataset=dataset,
        predictions=predictions,
        output_folder=output_folder,
        box_only=box_only,
        iou_types=iou_types,
        motion_specific=motion_specific,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )