-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathTableSarsa.lua
72 lines (65 loc) · 2.02 KB
/
TableSarsa.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
-- Implements the SARSA(lambda) algorithm for on-line policy control. Q-values,
-- visit counts and eligibility traces are all kept in rl.QHash / rl.VHash
-- containers (see __init) rather than in a parametric function approximator.
--
-- The class is registered as 'rl.TableSarsa' and derives from 'rl.Sarsa';
-- `parent` is the base class handle used below to chain the constructor.
local TableSarsa, parent = torch.class('rl.TableSarsa', 'rl.Sarsa')
-- Constructor.
-- Forwards `mdp_config` and `lambda` to the rl.Sarsa constructor, then
-- allocates the per-agent containers, all of them built from self.mdp:
--   Ns          (rl.VHash) incremented per state in update_eligibility
--   Nsa         (rl.QHash) incremented per (state, action) in update_eligibility
--   q           (rl.QHash) action-value estimates adjusted by td_update
--   eligibility (rl.QHash) eligibility traces
-- NOTE(review): self.mdp is not assigned here; presumably the parent
-- constructor sets it from mdp_config -- confirm.
function TableSarsa:__init(mdp_config, lambda)
    parent.__init(self, mdp_config, lambda)

    local mdp = self.mdp
    self.Ns = rl.VHash(mdp)
    self.Nsa = rl.QHash(mdp)
    self.q = rl.QHash(mdp)
    self.eligibility = rl.QHash(mdp)
end
-- Builds and returns a brand-new rl.QHash for this agent's MDP.
-- The agent's current self.q is left untouched.
function TableSarsa:get_new_q()
    local fresh_q = rl.QHash(self.mdp)
    return fresh_q
end
-- Discards the current eligibility traces by replacing self.eligibility with
-- a newly constructed rl.QHash for this agent's MDP.
function TableSarsa:reset_eligibility()
    local fresh_traces = rl.QHash(self.mdp)
    self.eligibility = fresh_traces
end
-- Updates the eligibility traces and visit counters for a visit to (s, a).
--
-- First every (state, action) trace is scaled by discount_factor * lambda,
-- then the trace of the visited pair is incremented by 1. Finally the visit
-- counters Ns(s) and Nsa(s, a) are each incremented by 1.
--
-- s: the state that was just visited
-- a: the action that was taken in s
function TableSarsa:update_eligibility(s, a)
    -- Decay pass over the full state/action grid.
    for _, each_state in pairs(self.mdp:get_all_states()) do
        for _, each_action in pairs(self.mdp:get_all_actions()) do
            local decay = self.discount_factor * self.lambda
            self.eligibility:mult(each_state, each_action, decay)
        end
    end

    -- Bump the trace of the pair that was actually visited.
    self.eligibility:add(s, a, 1)

    -- Visit statistics (read by get_step_size and by update_policy).
    self.Ns:add(s, 1)
    self.Nsa:add(s, a, 1)
end
-- Step size for the (state, action) pair: the reciprocal of its Nsa count.
-- When the count is zero, the zero count itself is returned instead of
-- dividing by zero, so such a pair gets a step size of zero.
local function get_step_size(self, state, action)
    local visits = self.Nsa:get_value(state, action)
    if visits ~= 0 then
        return 1. / visits
    end
    return visits
end
-- Applies one TD update to every entry of self.q.
--
-- For each (state, action) pair the Q-value is incremented by
--   step_size(state, action) * td_error * eligibility(state, action)
-- where the step size comes from get_step_size.
--
-- td_error: the temporal-difference error to apply to every pair
function TableSarsa:td_update(td_error)
    for _, each_state in pairs(self.mdp:get_all_states()) do
        for _, each_action in pairs(self.mdp:get_all_actions()) do
            local alpha = get_step_size(self, each_state, each_action)
            local trace = self.eligibility:get_value(each_state, each_action)
            local delta = alpha * td_error * trace
            self.q:add(each_state, each_action, delta)
        end
    end
end
-- Rebuilds the exploration strategy and the policy from the current tables.
--
-- self.explorer becomes a new rl.DecayTableExplorer constructed from
-- rl.MONTECARLOCONTROL_DEFAULT_N0 and the state visit counts self.Ns;
-- self.policy becomes a new rl.GreedyPolicy over self.q using that explorer.
-- NOTE(review): self.actions is not assigned anywhere in this class;
-- presumably the parent class provides it -- confirm.
function TableSarsa:update_policy()
    local explorer = rl.DecayTableExplorer(
        rl.MONTECARLOCONTROL_DEFAULT_N0,
        self.Ns)
    self.explorer = explorer

    self.policy = rl.GreedyPolicy(self.q, explorer, self.actions)
end
-- Equality metamethod.
-- Two agents are equal when they have the same torch type name and their
-- Ns, Nsa, q and eligibility members all compare equal with `==`, checked in
-- that order. Returns a boolean.
function TableSarsa:__eq(other)
    if torch.typename(self) ~= torch.typename(other) then
        return false
    end
    if self.Ns ~= other.Ns then
        return false
    end
    if self.Nsa ~= other.Nsa then
        return false
    end
    if self.q ~= other.q then
        return false
    end
    if self.eligibility ~= other.eligibility then
        return false
    end
    return true
end