-- TestMdpQVAnalyzer.lua
local TestMdpQVAnalyzer, parent =
    torch.class('rl.TestMdpQVAnalyzer', 'rl.QVAnalyzer')

-- Analyzer for the small test MDP: builds tensor views of state-value (V) and
-- action-value (Q) functions and plots them with gnuplot. Assumes torch, gnuplot,
-- and the rl package (rl.QVAnalyzer, rl.TestMdp, rl.VHash, rl.util) are already
-- loaded, e.g. via the package's init file.
function TestMdpQVAnalyzer:__init()
    parent.__init(self, rl.TestMdp())
    self.n_states = #self.mdp:get_all_states()
    self.n_actions = #self.mdp:get_all_actions()
end
-- Collect the values of a V estimator into a 1-D tensor indexed by state.
function TestMdpQVAnalyzer:get_v_tensor(v)
    local tensor = torch.zeros(self.n_states)
    for s = 1, self.n_states do
        tensor[s] = v:get_value(s)
    end
    return tensor
end

-- Plot the state-value function V(s) against the state index.
function TestMdpQVAnalyzer:plot_v(v)
    local tensor = self:get_v_tensor(v)
    local x = torch.Tensor(self.n_states)
    x = rl.util.apply_to_slices(x, 1, rl.util.fill_range, 0)
    gnuplot.plot(x, tensor)
    gnuplot.xlabel('State')
    gnuplot.ylabel('State Value')
    gnuplot.title('Monte-Carlo State Value Function')
end
-- Collect the values of a Q estimator into an (n_states x n_actions) tensor.
function TestMdpQVAnalyzer:get_q_tensor(q)
    local tensor = torch.zeros(self.n_states, self.n_actions)
    for s = 1, self.n_states do
        for a = 1, self.n_actions do
            tensor[s][a] = q:get_value(s, a)
        end
    end
    return tensor
end
-- Plot the greedy (argmax) action of a Q estimator for each state.
function TestMdpQVAnalyzer:plot_best_action(q)
    local best_action_at_state = torch.Tensor(self.n_states)
    for s = 1, self.n_states do
        best_action_at_state[s] = q:get_best_action(s)
    end
    local x = torch.Tensor(self.n_states)
    x = rl.util.apply_to_slices(x, 1, rl.util.fill_range, 0)
    gnuplot.plot(x, best_action_at_state)
    gnuplot.xlabel('State')
    gnuplot.ylabel('Best Action')
    gnuplot.title('Learned Best Action Based on q')
end
-- Derive a state-value function from a Q estimator by taking, for each state,
-- the value of its greedy action: V(s) = Q(s, argmax_a Q(s, a)).
function TestMdpQVAnalyzer:v_from_q(q)
    local v = rl.VHash(self.mdp)
    for s = 1, self.n_states do
        local a = q:get_best_action(s)
        v:add(s, q:get_value(s, a))
    end
    return v
end
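
-- A minimal usage sketch (not part of the original file). It assumes some Q
-- estimator exposing get_value(s, a) and get_best_action(s), e.g. one trained by
-- a learner elsewhere in the rl package; the name `q_estimator` is hypothetical.
--
--   local analyzer = rl.TestMdpQVAnalyzer()
--   local q_tensor = analyzer:get_q_tensor(q_estimator)  -- inspect raw Q values
--   local v = analyzer:v_from_q(q_estimator)             -- greedy state values
--   analyzer:plot_v(v)                                    -- plot V(s) per state
--   analyzer:plot_best_action(q_estimator)                -- plot greedy action per state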