-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathevaluator.py
31 lines (23 loc) · 1.13 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from evaluator.register import metrics_dict
from evaluator.collector import DataStruct
from collections import OrderedDict
class Evaluator(object):
    """Evaluator is used to check parameter correctness, and summarize the results of all metrics.

    Args:
        config: Configuration object indexed with ``config["metrics"]``. That entry names the
            metrics to compute, either as a single name or as an iterable of names. Names are
            lower-cased before being looked up in ``metrics_dict``. The whole ``config`` is
            also passed to every metric class constructor.

    Raises:
        KeyError: If a configured metric name is not registered in ``metrics_dict``.
    """

    def __init__(self, config):
        self.config = config
        metric_names = self.config["metrics"]
        # A bare string would otherwise be iterated character by character.
        if isinstance(metric_names, str):
            metric_names = [metric_names]
        self.metrics = [metric.lower() for metric in metric_names]
        self.metric_class = {}
        for metric in self.metrics:
            if metric in self.metric_class:
                # Duplicate entry in the config: one instance per metric name is enough.
                continue
            try:
                metric_cls = metrics_dict[metric]
            except KeyError as err:
                # Re-raise with context so a typo in the config is easy to spot.
                raise KeyError(
                    f"Unsupported metric {metric!r} in config['metrics']"
                ) from err
            self.metric_class[metric] = metric_cls(self.config)

    def evaluate(self, dataobject: DataStruct):
        """Calculate all the metrics. It is called at the end of each epoch.

        Args:
            dataobject (DataStruct): It contains all the information needed for metrics.

        Returns:
            collections.OrderedDict: such as ``{'hit@20': 0.3824, 'recall@20': 0.0527, 'hit@10': 0.3153, 'recall@10': 0.0329, 'gauc': 0.9236}``
        """
        result_dict = OrderedDict()
        # dict.fromkeys drops repeated names while keeping first-occurrence order, so a metric
        # listed twice is computed only once.
        for metric in dict.fromkeys(self.metrics):
            metric_val = self.metric_class[metric].calculate_metric(dataobject)
            # Each metric's output is merged in; a key produced by a later metric overwrites
            # the same key from an earlier one.
            result_dict.update(metric_val)
        return result_dict