forked from Neo-X/RL-Tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NNVisualize.py
114 lines (93 loc) · 4.34 KB
/
NNVisualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import matplotlib.pyplot as plt
# from matplotlib import mpl
import numpy as np
# import matplotlib.animation as animation
import random
import sys
import json
class NNVisualize(object):
    """Plot an error curve with a shaded band of ``error - std`` to ``error + std``.

    Typical use::

        viz = NNVisualize("My run")
        viz.init()
        viz.updateLoss(mean_error, std_error)
        viz.show()
    """

    def __init__(self, title):
        """Store the figure title; no matplotlib objects are created here.

        Args:
            title: Text used as the figure's super-title once init() runs.
        """
        self._title = title
        # Populated by init(). Declared here so that using the plot before
        # init() fails with a clear message instead of an AttributeError.
        self._fig = None
        self._bellman_error_ax = None
        self._bellman_error = None
        self._bellman_error_std = None

    def _require_figure(self):
        """Raise RuntimeError when init() has not created the figure yet."""
        if self._fig is None:
            raise RuntimeError(
                "NNVisualize.init() must be called before the plot is used")

    def init(self):
        """Create the figure holding a single error plot.

        Only the error axes is built; the mean-reward and discount-error
        panels of the earlier three-plot layout are not created.
        """
        self._fig, self._bellman_error_ax = plt.subplots(
            1, 1, sharey=False, sharex=True)
        self._bellman_error, = self._bellman_error_ax.plot(
            [], [], linewidth=2.0)
        # Placeholder band so updateLoss() always has an artist to replace.
        self._bellman_error_std = self._bellman_error_ax.fill_between(
            [0], [0], [1], facecolor='blue', alpha=0.5)
        self._bellman_error_ax.set_title('Error')
        self._bellman_error_ax.set_ylabel("Absolute Error")
        self._fig.suptitle(self._title, fontsize=18)
        self._fig.set_size_inches(8.0, 4.5, forward=True)

    def updateLoss(self, error, std):
        """Replace the plotted curve and its shaded band with new data.

        Args:
            error: Sequence of error values, one per x position; plotted
                against 0..len(error)-1.
            std: Sequence of the same length as ``error``; the band is
                drawn between ``error - std`` and ``error + std``.

        Raises:
            RuntimeError: If init() has not been called yet.
        """
        self._require_figure()
        # Accept plain lists as well as arrays: the band arithmetic below
        # needs element-wise subtraction and addition.
        error = np.asarray(error)
        std = np.asarray(std)
        steps = np.arange(len(error))

        self._bellman_error.set_xdata(steps)
        self._bellman_error.set_ydata(error)

        # Detach the previous band through the artist itself. Mutating
        # Axes.collections directly is deprecated since matplotlib 3.5 and
        # the list is read-only in later releases.
        self._bellman_error_std.remove()
        self._bellman_error_std = self._bellman_error_ax.fill_between(
            steps, error - std, error + std, facecolor='blue', alpha=0.5)

        self._bellman_error_ax.relim()  # make sure all the data fits
        self._bellman_error_ax.autoscale()

    def show(self):
        """Display all open figures via ``plt.show()``."""
        plt.show()

    def redraw(self):
        """Redraw this figure's canvas.

        Raises:
            RuntimeError: If init() has not been called yet.
        """
        self._require_figure()
        self._fig.canvas.draw()

    def setInteractive(self):
        """Turn matplotlib interactive mode on."""
        plt.ion()

    def setInteractiveOff(self):
        """Turn matplotlib interactive mode off."""
        plt.ioff()

    def saveVisual(self, fileName):
        """Save the figure to ``fileName + ".svg"``.

        Args:
            fileName: Output path without the ``.svg`` extension.

        Raises:
            RuntimeError: If init() has not been called yet.
        """
        self._require_figure()
        self._fig.savefig(fileName + ".svg")
if __name__ == "__main__":
    # Usage: python NNVisualize.py <training_data.json>
    #
    # The JSON file must provide "mean_bellman_error" and
    # "std_bellman_error" (lists of numbers); those are the only keys read
    # here. The earlier note in this script also listed "mean_reward",
    # "std_reward", "mean_discount_error" and "std_discount_error".
    if len(sys.argv) < 2:
        sys.exit("usage: NNVisualize.py <training_data.json>")
    datafile = sys.argv[1]

    # The context manager closes the file even if JSON parsing fails.
    with open(datafile) as data_fp:
        trainData = json.load(data_fp)

    # NNVisualize requires a title, and init() must build the figure before
    # any data can be pushed into it.
    rlv = NNVisualize("Bellman Error: " + datafile)
    rlv.init()
    rlv.updateLoss(np.array(trainData["mean_bellman_error"]),
                   np.array(trainData["std_bellman_error"]))
    rlv.show()