-
Notifications
You must be signed in to change notification settings - Fork 3
/
plot_all.py
102 lines (88 loc) · 3.03 KB
/
plot_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
'''
Script to plot all runs in a folder.
Say you have
data-trial/ppo/Pendulum-v0/190312_140449
data-trial/ppo/Pendulum-v0/190412_141241
...
This script will plot one line (for the desired data) for each run.
Each line will have a legend entry with the filename.
You can then press R to refresh the plot (e.g., if some trials are still running)
or P to save the plot as pdf,
or ESC to close plot and end the program.
'''
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import matplotlib
plt.ion()
#matplotlib.use('TKagg')
import seaborn as sns
import itertools
import sys
def pdf_path(name):
    """Return *name* with exactly one ``.pdf`` extension.

    ``--pdf plot`` and ``--pdf plot.pdf`` both map to ``plot.pdf``
    (previously the latter became ``plot.pdf.pdf``).
    """
    return name if name.lower().endswith('.pdf') else name + '.pdf'


def load_run(path, col):
    """Load column *col* of the whitespace-separated .dat file at *path*.

    Returns a 1-D array of the finite values in that column, or ``None``
    when the run should be skipped:
      * the file has no rows (skipped silently, as before);
      * the file cannot be parsed, e.g. a trial that is still running left a
        half-written last line (prints 'Cannot read <file>');
      * the file has no column *col* (prints 'Cannot read <file>').
    NaN and +/-inf entries are dropped, so later points shift left.
    """
    fname = os.path.basename(path)
    try:
        # ndmin=2 keeps single-row and single-column files 2-D, so that
        # [:, col] selects a column instead of failing on a 1-D array.
        data_mat = np.loadtxt(path, ndmin=2)
    except (OSError, ValueError):
        # OSError: file vanished / unreadable; ValueError: malformed or
        # partially written content (rows with differing column counts).
        print('Cannot read', fname)
        return None
    if data_mat.shape[0] == 0:
        return None
    try:
        data = data_mat[:, col]
    except IndexError:
        print('Cannot read', fname)
        return None
    return data[np.isfinite(data)]


def collect_runs(folder, col):
    """Return ``[(filename, data), ...]`` for every readable ``.dat`` file.

    Files are visited in sorted filename order; files for which
    :func:`load_run` returns ``None`` are omitted.
    """
    runs = []
    for fname in sorted(os.listdir(folder)):
        if not fname.endswith('.dat'):
            continue
        data = load_run(os.path.join(folder, fname), col)
        if data is not None:
            runs.append((fname, data))
    return runs


def plot_runs(folder, col):
    """Redraw the current axes with one line per run found in *folder*.

    Each line gets the next seaborn palette colour and line style, and a
    legend entry with its filename. Exits the program when nothing can be
    plotted.
    """
    plt.cla()
    runs = collect_runs(folder, col)
    if not runs:
        print('nothing to plot, quit')
        sys.exit(0)
    palette = itertools.cycle(sns.color_palette())
    linestyles = itertools.cycle(['-', '--', '-.', ':'])
    for _, data in runs:
        plt.plot(data, color=next(palette), alpha=0.7,
                 linestyle=next(linestyles))
    labels = [fname for fname, _ in runs]
    legend = plt.legend(handles=plt.gca().lines, labels=labels, loc='best')
    legend.get_frame().set_facecolor('white')
    plt.draw()


def save_pdf(pdf_name):
    """Save the current figure as a PDF if *pdf_name* is given.

    Returns True if a file was written, False if no name was given.
    """
    if pdf_name is None:
        return False
    plt.savefig(pdf_path(pdf_name), bbox_inches='tight', pad_inches=0)
    return True


def main():
    """Parse arguments, plot all runs, then serve refresh/save commands."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--col', type=int, help='index of the column to plot', required=True)
    parser.add_argument('--src', type=str, help='folder where .dat files are stored', required=True)
    parser.add_argument('--pdf', type=str, help='(optional) filename to save as pdf', required=False)
    args = parser.parse_args()

    def refresh():
        plot_runs(args.src, args.col)
        print('refreshed')

    def save():
        if save_pdf(args.pdf):
            print('saved as pdf')
        else:
            print('no --pdf filename given, nothing saved')

    def on_key(event):
        # NOTE: matplotlib's default keymap also binds 'r' (home view) and
        # 'p' (pan mode); those default actions fire as well.
        if event.key in ('r', 'R'):
            refresh()
        elif event.key in ('p', 'P'):
            save()
        elif event.key == 'escape':
            print('quit')
            sys.exit(0)

    def on_close(event):
        sys.exit(0)

    sns.set_context("paper")
    sns.set_style('darkgrid', {'legend.frameon': True})
    fig = plt.figure()
    plt.axes()
    fig.set_size_inches(fig.get_size_inches() / 1.3)
    fig.canvas.mpl_connect('key_press_event', on_key)
    fig.canvas.mpl_connect('close_event', on_close)

    plot_runs(args.src, args.col)
    save_pdf(args.pdf)

    # Console fallback: 'r' refreshes, 'p' saves, anything else ends the
    # program. EOF on stdin (closed / non-interactive) also ends it cleanly.
    while True:
        try:
            command = input('')
        except EOFError:
            break
        if command in ('r', 'R'):
            refresh()
        elif command in ('p', 'P'):
            save()
        else:
            break


if __name__ == '__main__':
    main()