-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvalue_plot.py
40 lines (40 loc) · 1.49 KB
/
value_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import matplotlib.pyplot as plt
import numpy as np
"""
Input:
Q_tab : Tabulr Q (numpy matrix |S| by |A|)
env : an environment object (e.g. env = Maze())
isMaze : fixed to True
arrow : True if you want to plot arrows.s
"""
def value_plot(Q_tab, env, isMaze = True, arrow = True):
    """Plot the greedy state values of a tabular Q function as grayscale images.

    The state value is the row-wise maximum of ``Q_tab`` and the greedy action
    is the row-wise argmax.  The goal cell is always painted slightly brighter
    than the largest plotted value so it stands out.

    Args:
        Q_tab: Tabular Q values, array-like of shape (|S|, |A|).  Action
            indices 0..3 must match the keys of the ``direction`` table below
            when ``arrow`` is true.
        env: Environment object.  ``env.goal_pos`` is used to index the image.
            When ``isMaze`` is true, ``env.idx2cell`` (cell index -> (row, col))
            and ``env.dim`` (image shape) are read as well.
        isMaze: If true, Q row ``8 * j + i`` is taken to belong to cell ``j``
            and one figure is drawn for each ``i`` in 0..7.  Otherwise the
            states are laid out on an n-by-n grid with ``n = int(sqrt(|S|))``
            and state ``n * c + r`` is drawn at row ``r``, column ``c``; the
            last state (bottom-right cell) gets no value and no arrow.
        arrow: If true, draw a red arrow for the greedy action in each cell.

    Returns:
        None.  The figures are displayed with ``plt.show()``.
    """
    # Action index -> (dx, dy) arrow offset, in plot (x, y) coordinates.
    direction = {0: (0, -0.4), 1: (0, 0.4), 2: (-0.4, 0), 3: (0.4, 0)}
    arrow_style = dict(head_width=0.05, head_length=0.1, fc='r', ec='r')

    # asarray keeps the reductions 1-D even if a numpy.matrix is passed in.
    q_values = np.asarray(Q_tab)
    V = q_values.max(axis=1)
    best_action = q_values.argmax(axis=1)

    if isMaze:
        states_per_cell = 8
        idx2cell = env.idx2cell
        goal_value = V.max() + 0.1
        for i in range(states_per_cell):
            _fig, ax = plt.subplots()
            y_mat = np.zeros(env.dim)
            # Index idx2cell by position instead of iterating it: it may be a
            # mapping keyed by cell index rather than a sequence.
            for j in range(len(idx2cell)):
                pos = idx2cell[j]
                state = states_per_cell * j + i
                y_mat[pos[0], pos[1]] = V[state]
                if arrow:
                    dx, dy = direction[int(best_action[state])]
                    # Image x is the column and y is the row.
                    ax.arrow(pos[1], pos[0], dx, dy, **arrow_style)
            y_mat[env.goal_pos] = goal_value
            ax.imshow(y_mat, cmap='gray')
    else:
        n = int(np.sqrt(len(V)))
        tab = np.zeros((n, n))
        for r in range(n):
            for c in range(n):
                if r == n - 1 and c == n - 1:
                    # Last state: its value is supplied via env.goal_pos below.
                    continue
                state = n * c + r
                tab[r, c] = V[state]
                if arrow:
                    dx, dy = direction[int(best_action[state])]
                    plt.arrow(c, r, dx, dy, **arrow_style)
        # The last state is excluded from the maximum.
        tab[env.goal_pos] = V[:-1].max() + 0.1
        plt.imshow(tab, cmap='gray')
    plt.show()