forked from minghanz/EquivReg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path metricrecord.py
69 lines (60 loc) · 2.34 KB
/
metricrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import numpy as np
from collections import defaultdict
class Record:
    """Keep the ``num`` best (value, item) pairs seen so far.

    Entries are kept sorted best-first: descending when ``highest`` is True,
    ascending otherwise. ``vals[-1]`` is therefore always the worst retained
    value, and it is the entry a better candidate evicts.
    """

    def __init__(self, name, num=10, highest=True) -> None:
        """
        Args:
            name: Label used by ``__str__``.
            num: Maximum number of entries to keep; ``num <= 0`` keeps nothing.
            highest: If True keep the largest values, otherwise the smallest.
        """
        self.name = name
        self.num = num
        self.highest = highest
        self.vals = []   # retained values, best first
        self.items = []  # items paired index-for-index with ``vals``

    def update(self, val, item):
        """Offer a (value, item) pair and keep it if it ranks among the best ``num``.

        A ``torch.Tensor`` value is turned into a Python scalar with ``.item()``,
        matching ``Metric.update``, so the record does not keep tensors alive.
        A value equal to the current worst entry does not displace it.
        """
        if isinstance(val, torch.Tensor):
            val = val.item()
        if self.num <= 0:
            # No capacity: there is no worst entry to compare against.
            return
        if len(self.vals) < self.num:
            # Still filling up: accept unconditionally.
            self.vals.append(val)
            self.items.append(item)
        else:
            worst = self.vals[-1]
            is_better = val > worst if self.highest else val < worst
            if not is_better:
                return
            # Evict the current worst entry in favour of the new one.
            self.vals[-1] = val
            self.items[-1] = item
        self.sort()

    def sort(self):
        """Reorder ``vals`` and ``items`` together so the best value comes first.

        ``sorted`` is stable even with ``reverse=True``, so equal values keep
        their current relative order.
        """
        order = sorted(range(len(self.vals)), key=self.vals.__getitem__, reverse=self.highest)
        self.vals = [self.vals[i] for i in order]
        self.items = [self.items[i] for i in order]

    def __str__(self) -> str:
        return f"{self.name}: {list(zip(self.items, self.vals))}"
class Metric:
    """Running mean of scalar observations.

    ``val`` accumulates the sum of everything passed to ``update`` and
    ``count`` the number of calls; ``avg`` divides one by the other.
    """

    def __init__(self, name) -> None:
        self.name = name
        self.val = 0    # running sum of observed values
        self.count = 0  # number of observations added

    def update(self, val):
        """Add one observation; a ``torch.Tensor`` is reduced with ``.item()`` first."""
        amount = val.item() if isinstance(val, torch.Tensor) else val
        self.val += amount
        self.count += 1

    def avg(self):
        """Return the mean of all observations (0.0 before any update)."""
        # Guard against dividing by zero when nothing has been recorded yet.
        denominator = self.count if self.count >= 1 else 1
        return self.val / denominator

    def __str__(self) -> str:
        mean = self.avg()
        return f"{self.name}: {mean} (avg over {self.count})"
if __name__ == '__main__':
    # Example usage: with highest=False the record keeps the 5 *smallest*
    # scores, so label it accordingly (it was previously called "Top Scores").
    record = Record("Lowest Scores", num=5, highest=False)
    scores = [100, 50, 75, 120, 80, 60, 110]
    names = ["Alice", "Bob", "Carol", "David", "Eve", "Frank", "Grace"]
    for name, score in zip(names, scores):
        record.update(score, name)
    print(record)