trackeval_evaluator.py

import numpy as np
import logging
import trackeval
from pathlib import Path
from tabulate import tabulate
from tracklab.core import Evaluator as EvaluatorBase
from tracklab.utils import wandb

log = logging.getLogger(__name__)


class TrackEvalEvaluator(EvaluatorBase):
    """
    Evaluator using the TrackEval library (https://github.com/JonathonLuiten/TrackEval).
    Saves the tracking predictions and ground truth to disk in MOT Challenge format,
    then runs the evaluation by calling TrackEval.
    """
    def __init__(self, cfg, eval_set, show_progressbar, dataset_path, tracking_dataset, *args, **kwargs):
        self.cfg = cfg
        self.tracking_dataset = tracking_dataset
        self.eval_set = eval_set
        self.trackeval_dataset_name = cfg.dataset.dataset_class
        self.trackeval_dataset_class = getattr(trackeval.datasets, self.trackeval_dataset_name)
        self.show_progressbar = show_progressbar
        self.dataset_path = dataset_path

    def run(self, tracker_state):
        log.info("Starting evaluation using TrackEval library (https://github.com/JonathonLuiten/TrackEval)")

        tracker_name = 'tracklab'
        save_classes = self.trackeval_dataset_class.__name__ != 'MotChallenge2DBox'

        # Save predictions
        pred_save_path = Path(self.cfg.dataset.TRACKERS_FOLDER) / f"{self.trackeval_dataset_class.__name__}-{self.eval_set}" / tracker_name
        self.tracking_dataset.save_for_eval(
            tracker_state.detections_pred,
            tracker_state.image_pred,
            tracker_state.video_metadatas,
            pred_save_path,
            self.cfg.bbox_column_for_eval,
            save_classes,  # do not use classes for MotChallenge2DBox
            is_ground_truth=False,
        )
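        # Reference note (assumption about the on-disk layout, not enforced here):
        # MOTChallenge result files are one .txt per sequence with rows of the form
        #   <frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>
        # save_for_eval() is expected to write files that TrackEval can parse.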
        log.info(
            f"Tracking predictions saved in {self.trackeval_dataset_name} format in {pred_save_path}")

        if tracker_state.detections_gt is None or len(tracker_state.detections_gt) == 0:
            log.warning(
                f"Stopping evaluation because the current split ({self.eval_set}) has no ground truth detections.")
            return

        # Save ground truth
        if self.cfg.save_gt:
            gt_save_path = Path(self.cfg.dataset.GT_FOLDER) / f"{self.trackeval_dataset_name}-{self.eval_set}"
            self.tracking_dataset.save_for_eval(
                tracker_state.detections_gt,
                tracker_state.image_gt,
                tracker_state.video_metadatas,
                gt_save_path,
                self.cfg.bbox_column_for_eval,
                save_classes,
                is_ground_truth=True,
            )
            log.info(
                f"Tracking ground truth saved in {self.trackeval_dataset_name} format in {gt_save_path}")

        # Build TrackEval dataset
        dataset_config = self.trackeval_dataset_class.get_default_dataset_config()
        dataset_config['SEQ_INFO'] = tracker_state.video_metadatas.set_index('name')['nframes'].to_dict()
        dataset_config['BENCHMARK'] = self.trackeval_dataset_class.__name__  # required for trackeval.datasets.MotChallenge2DBox
        for key, value in self.cfg.dataset.items():
            dataset_config[key] = value
        if not self.cfg.save_gt:
            dataset_config['GT_FOLDER'] = self.dataset_path  # location of the raw GT data
            dataset_config['GT_LOC_FORMAT'] = '{gt_folder}/{seq}/Labels-GameState.json'  # default: '{gt_folder}/{seq}/gt/gt.txt'
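            # Assumption: when the GT is not re-exported above, TrackEval reads it
            # straight from the raw dataset folder, using per-sequence JSON labels
            # (e.g. SoccerNet Game State) instead of the default MOT gt.txt layout.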
        dataset = self.trackeval_dataset_class(dataset_config)

        # Build metrics
        metrics_config = {'METRICS': set(self.cfg.metrics), 'PRINT_CONFIG': False, 'THRESHOLD': 0.5}
        metrics_list = []
        for metric_name in self.cfg.metrics:
            try:
                metric = getattr(trackeval.metrics, metric_name)
                metrics_list.append(metric(metrics_config))
            except AttributeError:
                log.warning(f'Skipping evaluation for unknown metric: {metric_name}')
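        # Metric names are resolved against trackeval.metrics, e.g. 'HOTA',
        # 'CLEAR' or 'Identity'; anything not found there is skipped with the
        # warning above.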
        # Build evaluator
        eval_config = trackeval.Evaluator.get_default_eval_config()
        for key, value in self.cfg.eval.items():
            eval_config[key] = value
        evaluator = trackeval.Evaluator(eval_config)

        # Run evaluation
        output_res, output_msg = evaluator.evaluate([dataset], metrics_list, show_progressbar=self.show_progressbar)

        # Log results
        results = output_res[dataset.get_name()][tracker_name]
        # if the dataset has the process_trackeval_results method, use it to process the results
        if hasattr(self.tracking_dataset, 'process_trackeval_results'):
            results = self.tracking_dataset.process_trackeval_results(results, dataset_config, eval_config)
        wandb.log(results)


def _print_results(
    res_combined,
    res_by_video=None,
    scale_factor=1.0,
    title="",
    print_by_video=False,
):
    headers = res_combined.keys()
    data = [
        format_metric(name, res_combined[name], scale_factor)
        for name in headers
    ]
    log.info(f"{title}\n" + tabulate([data], headers=headers, tablefmt="plain"))
    if print_by_video and res_by_video:
        data = []
        for video_name, res in res_by_video.items():
            video_data = [video_name] + [
                format_metric(name, res[name], scale_factor)
                for name in headers
            ]
            data.append(video_data)
        headers = ["video"] + list(headers)
        log.info(
            f"{title} by videos\n"
            + tabulate(data, headers=headers, tablefmt="plain")
        )


def format_metric(metric_name, metric_value, scale_factor):
    if (
        "TP" in metric_name
        or "FN" in metric_name
        or "FP" in metric_name
        or "TN" in metric_name
    ):
        # "MOTP" contains the substring "TP", so it would otherwise be cast to int
        # below; keep it as a scaled, rounded float like the other ratio metrics.
        if metric_name == "MOTP":
            return np.around(metric_value * scale_factor, 3)
        return int(metric_value)
    else:
        return np.around(metric_value * scale_factor, 3)
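

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). It shows how the
# evaluator could be instantiated outside of tracklab's Hydra pipeline, assuming
# an OmegaConf config whose keys mirror the attributes read above
# (cfg.dataset.dataset_class, cfg.bbox_column_for_eval, cfg.save_gt, cfg.metrics,
# cfg.eval). The concrete values and the `my_tracking_dataset` / `tracker_state`
# objects are placeholders that would normally come from a tracklab run.
# ---------------------------------------------------------------------------
# from omegaconf import OmegaConf
#
# cfg = OmegaConf.create({
#     "dataset": {
#         "dataset_class": "MotChallenge2DBox",  # any dataset class in trackeval.datasets
#         "TRACKERS_FOLDER": "eval/trackers",
#         "GT_FOLDER": "eval/gt",
#     },
#     "bbox_column_for_eval": "bbox_ltwh",
#     "save_gt": True,
#     "metrics": ["HOTA", "CLEAR", "Identity"],
#     "eval": {"USE_PARALLEL": False},
# })
# evaluator = TrackEvalEvaluator(
#     cfg=cfg,
#     eval_set="val",
#     show_progressbar=True,
#     dataset_path="data/MOT17",
#     tracking_dataset=my_tracking_dataset,  # must implement save_for_eval(...)
# )
# evaluator.run(tracker_state)  # tracker_state holds detections_pred, image_pred, ...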