-
Notifications
You must be signed in to change notification settings - Fork 12
/
beautify.py
68 lines (53 loc) · 1.77 KB
/
beautify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import tensorflow as tf
from tensorflow.python.summary import event_accumulator as ea
from matplotlib import pyplot as plt
from matplotlib import colors as colors
import seaborn as sns
# Global seaborn theme for every figure drawn by this script: dark grid
# background, with the compact "paper" context governing font sizes and
# line widths.
sns.set(context="paper", style="darkgrid")
def _smooth(x, y, window):
    """Average ``y`` over consecutive blocks of ``window`` points.

    Each smoothed point is placed at the step ``x`` of the first sample in
    its block. If the final raw sample does not itself start a block, it is
    appended unchanged so the smoothed curve reaches the last logged step.
    Empty input yields two empty lists.

    Returns:
        (xs, ys): two lists of equal length.
    """
    # The CLI may pass a float (or a value < 1); range() needs an int step >= 1.
    window = max(1, int(window))
    starts = range(0, len(y), window)
    xs, ys = [], []
    for i in starts:
        block = y[i:i + window]
        xs.append(x[i])
        # Divide by the real block length: the last block can be shorter than
        # ``window``, and dividing by ``window`` would bias its mean toward 0.
        ys.append(sum(block) / float(len(block)))
    if starts and starts[-1] != len(y) - 1:
        # Anchor the curve at the final raw point (skipped when the last block
        # already starts there, which would otherwise duplicate that step).
        xs.append(x[-1])
        ys.append(y[-1])
    return xs, ys


def plot(params):
    """Beautify a TensorFlow event log.

    Use a nicer library (seaborn) to plot the scalar summaries of a tf event
    file. Each scalar tag gets its own matplotlib figure. The raw curve is
    drawn translucent (alpha 0.4), and a block-averaged curve is drawn on top
    in the same color. All figures are shown at the end.

    Args:
        params: dict with keys
            'logdir' -- path handed to ``EventAccumulator``;
            'smooth' -- averaging window size in points (coerced to int, min 1);
            'color'  -- matplotlib color spec, e.g. an HTML hex code.
    """
    log_path = params['logdir']
    smooth_space = params['smooth']
    color_code = params['color']

    acc = ea.EventAccumulator(log_path)
    acc.Reload()
    # only support scalar now
    scalar_list = acc.Tags()['scalars']

    # Same hue as the smoothed curve, but faded so the smoothed one stands out.
    raw_color = colors.to_rgba(color_code, alpha=0.4)
    for fig_num, tag in enumerate(scalar_list):
        events = acc.Scalars(tag)  # fetch once; reused for both axes
        x = [int(s.step) for s in events]
        y = [s.value for s in events]
        x_smooth, y_smooth = _smooth(x, y, smooth_space)

        plt.figure(fig_num)
        plt.subplot(111)
        plt.title(tag)
        plt.plot(x, y, color=raw_color)
        plt.plot(x_smooth, y_smooth, color=color_code, linewidth=1.5)
    plt.show()
if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to plot() as a
    # plain dict with the 'logdir' / 'smooth' / 'color' keys it reads.
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', default='./logdir', type=str, help='logdir to event file')
    # int, not float: the window size is used as a range() step and a slice
    # bound, both of which reject floats (`--smooth 50` used to become 50.0).
    parser.add_argument('--smooth', default=100, type=int, help='window size for average smoothing')
    parser.add_argument('--color', default='#4169E1', type=str, help='HTML code for the figure')
    args = parser.parse_args()
    if args.smooth < 1:
        # A zero/negative window would make range() raise or yield nothing.
        parser.error('--smooth must be >= 1')
    params = vars(args)  # convert to ordinary dict
    plot(params)