-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
107 lines (84 loc) · 2.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import logging
import random
import sys
from collections import OrderedDict
import numpy as np
import torch
def reset_seed(seed):
    """Seed every random number generator used in training and make cuDNN deterministic.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same ``seed``.

    Parameters
    ----------
    seed : int
        Seed value passed to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Ask cuDNN for deterministic kernels...
    torch.backends.cudnn.deterministic = True
    # ...and turn off the autotuner: with benchmark mode on, cuDNN may pick a
    # different algorithm on each run, which breaks reproducibility even when
    # `deterministic` is set.
    torch.backends.cudnn.benchmark = False
def to_cuda(obj):
    """Recursively move every tensor inside ``obj`` onto the default CUDA device.

    Containers are rebuilt with their elements converted: any ``dict`` becomes a
    plain ``dict``, any ``list`` a plain ``list`` and any ``tuple`` a plain
    ``tuple``. ``int``, ``float`` and ``str`` values are returned as-is.

    Raises
    ------
    ValueError
        If ``obj`` (or anything nested in it) is of any other type.
    """
    if torch.is_tensor(obj):
        return obj.cuda()
    if isinstance(obj, dict):
        return {key: to_cuda(value) for key, value in obj.items()}
    if isinstance(obj, (tuple, list)):
        converted = [to_cuda(item) for item in obj]
        # Tuples (including subclasses) come back as a plain tuple, lists as a plain list.
        return tuple(converted) if isinstance(obj, tuple) else converted
    if isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
def get_logger():
    """Reset the root logger so it prints INFO-and-above records to stdout.

    Every handler currently attached to the root logger is removed (not
    closed). It is replaced by one ``StreamHandler`` writing to ``sys.stdout``
    with lines shaped like ``[MM/DD HH:MM:SS] LEVEL (logger name) message``.

    Returns
    -------
    logging.Logger
        The root logger.
    """
    root = logging.getLogger()
    if root.hasHandlers():
        root.handlers.clear()

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(
        logging.Formatter(
            fmt="[%(asctime)s] %(levelname)s (%(name)s) %(message)s",
            datefmt="%m/%d %H:%M:%S",
        )
    )
    root.addHandler(stdout_handler)
    root.setLevel(logging.INFO)
    return root
class AverageMeterGroup:
    """Average meter group for multiple average meters.

    One ``AverageMeter`` is created lazily for each key seen by :meth:`update`.
    Meters are kept in insertion order and can be read as ``group[name]`` or
    ``group.name``.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data, n=1):
        """Feed each value in ``data`` into the meter with the same key.

        Parameters
        ----------
        data : dict
            Mapping from meter name to its new value.
        n : int
            Weight forwarded to every meter's ``update`` call.
        """
        for k, v in data.items():
            if k not in self.meters:
                # NOTE(review): ":4f" means minimum width 4 with the default
                # 6-digit precision; ":.4f" (4 decimals) may have been intended.
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v, n=n)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. It must raise
        # AttributeError (not KeyError) so hasattr(), getattr(..., default),
        # copy and pickle behave. Read `meters` via __dict__ so an instance
        # created without __init__ (e.g. during unpickling) does not recurse
        # back into __getattr__('meters') forever.
        try:
            return self.__dict__["meters"][item]
        except KeyError:
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, item)
            ) from None

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """Return the space-joined ``summary()`` of every meter, in insertion order."""
        return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        """
        Initialization of AverageMeter

        Parameters
        ----------
        name : str
            Name to display.
        fmt : str
            Format spec, including the leading ``:``, applied to both the
            latest value and the average (e.g. ``':f'`` or ``':.4f'``).
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        ``val`` becomes the latest value and contributes ``val * n`` to the
        sum, so the average is ``sum(val_i * n_i) / sum(n_i)``. While the
        total weight is zero (e.g. only ``n=0`` updates so far), the average
        is left unchanged instead of raising ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` using ``self.fmt`` for both numbers."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return ``'<name>: <avg>'`` using ``self.fmt``."""
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)