-
Notifications
You must be signed in to change notification settings - Fork 10
/
logger.py
69 lines (56 loc) · 2.37 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import sys
import random
import numpy as np
from collections import OrderedDict
from tabulate import tabulate
from pandas import DataFrame
from time import gmtime, strftime
class Logger:
    """Collect scalar metrics and mirror console output to files in ``./logs``.

    Every instance builds a path prefix ``self.path`` from the last three
    components of ``sys.argv[0]``, the given ``name``, a random three-letter
    tag and the current UTC time (``gmtime``).  Three file names share that
    prefix:

    * ``self.logs`` (``.csv``) -- metric table written by :meth:`save`;
    * ``self.output`` (``.out``) -- text appended by :meth:`print`;
    * ``self.checkpoint`` (``.cpt``) -- only the path is computed here; this
      class never writes to it.
    """

    def __init__(self, name='name', fmt=None):
        """Create the ``./logs`` directory if needed and compute output paths.

        Args:
            name: Label embedded in the generated file names.
            fmt: Optional mapping from metric name to a float format spec
                (for example ``'.3f'``) used by :meth:`iter_info`.  Metrics
                without an entry are formatted with ``'.1f'``.
        """
        # True until the first iter_info() call, which prints the header row.
        self.handler = True
        self.scalar_metrics = OrderedDict()
        self.fmt = fmt if fmt else dict()

        base = './logs'
        # exist_ok avoids the check-then-create race when several runs start
        # at the same moment.
        os.makedirs(base, exist_ok=True)

        now = gmtime()
        # Three random lowercase ASCII letters (code points 97..122) keep
        # runs started within the same minute from sharing a prefix.
        tag = ''.join(chr(random.randint(97, 122)) for _ in range(3))
        script = '-'.join(sys.argv[0].split('/')[-3:])
        self.path = '%s/%s-%s-%s-%s' % (
            base, script, name, tag, strftime('%m-%d-%H:%M', now))
        self.logs = self.path + '.csv'
        self.output = self.path + '.out'
        self.checkpoint = self.path + '.cpt'

    def print(self, *args):
        """Append the space-joined ``args`` to ``self.output`` and echo them.

        Each argument is converted with ``str``; the resulting line is written
        to the ``.out`` file and then printed to stdout, which is flushed.
        """
        message = ' '.join(map(str, args))
        # Leaving the ``with`` block closes (and therefore flushes) the file.
        with open(self.output, 'a') as f:
            f.write(message + '\n')
        # Inside a method body ``print`` still resolves to the builtin.
        print(message, flush=True)

    def add_scalar(self, t, key, value):
        """Record ``value`` for metric ``key`` at step ``t``."""
        self.scalar_metrics.setdefault(key, []).append((t, value))

    def iter_info(self, order=None):
        """Print the latest value of each metric as one table row.

        The first call prints a table with a header; later calls print only
        the data row in tabulate's ``plain`` format.  The ``epoch`` column is
        the largest step among the latest entries of the reported metrics.

        Args:
            order: Optional sequence of metric names selecting which metrics
                to show and in which order.  Defaults to every recorded
                metric in insertion order.

        Raises:
            KeyError: If ``order`` names a metric that was never recorded.
        """
        names = list(order) if order else list(self.scalar_metrics)
        if not names:
            # Nothing recorded yet: there is no row to print.
            return

        latest = [self.scalar_metrics[name][-1] for name in names]
        values = [value for _, value in latest]
        t = int(np.max([step for step, _ in latest]))

        # The leading entry lines up with the integer ``epoch`` column.
        fmt = ['%s'] + [self.fmt.get(name, '.1f') for name in names]
        headers = ['epoch'] + names
        rows = [[t] + values]
        if self.handler:
            self.handler = False
            self.print(tabulate(rows, headers, floatfmt=fmt))
        else:
            # With headers, plain output is "header\nrow"; keep only the row.
            table = tabulate(rows, headers, tablefmt='plain', floatfmt=fmt)
            self.print(table.split('\n')[1])

    def save(self):
        """Write all recorded metrics to ``self.logs`` as a CSV indexed by ``t``.

        Each metric becomes one column; metrics are outer-joined on the step,
        so steps missing for a metric are left empty.  When nothing has been
        recorded, a CSV containing only the ``t`` index header is written.
        """
        result = None
        for key, history in self.scalar_metrics.items():
            frame = DataFrame(history, columns=['t', key]).set_index('t')
            result = frame if result is None else result.join(frame, how='outer')
        if result is None:
            # No metrics at all: still produce a valid, empty table.
            result = DataFrame({'t': []}).set_index('t')
        result.to_csv(self.logs)
        self.print('The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt')