deterministic.py (forked from eleurent/rl-agents)
import numpy as np
import logging
from rl_agents.agents.common.factory import safe_deepcopy_env
from rl_agents.agents.tree_search.abstract import Node, AbstractTreeSearchAgent, AbstractPlanner
logger = logging.getLogger(__name__)
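
# Optimistic planning for deterministic MDPs: every node of the planning tree
# keeps a lower bound on the value of the action sequence reaching it (the
# discounted sum of rewards collected along the path) and an upper bound equal
# to that lower bound plus gamma**depth / (1 - gamma), the best return still
# achievable below the node when rewards lie in [0, 1]. The planner repeatedly
# expands the leaf with the highest upper bound and backs the bounds up to the root.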
class DeterministicNode(Node):
    def __init__(self, parent, planner, state=None, depth=0):
        super().__init__(parent, planner)
        self.state = state
        self.observation = None
        self.depth = depth
        self.reward = 0
        self.value_upper = 0
        self.value_lower = 0
        self.count = 1
        self.done = False

    def selection_rule(self):
        if not self.children:
            return None
        actions = list(self.children.keys())
        index = self.random_argmax([self.children[a].get_value_lower_bound() for a in actions])
        return actions[index]

    def expand(self):
        self.planner.leaves.remove(self)
        if self.state is None:
            raise Exception("The state should be set before expanding a node")
        try:
            actions = self.state.get_available_actions()
        except AttributeError:
            actions = range(self.state.action_space.n)
        for action in actions:
            self.children[action] = type(self)(self,
                                               self.planner,
                                               state=safe_deepcopy_env(self.state),
                                               depth=self.depth + 1)
            observation, reward, done, truncated, info = self.planner.step(self.children[action].state, action)
            self.planner.leaves.append(self.children[action])
            self.children[action].update(reward, done, observation)
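
    # Bound bookkeeping done by update(): for a child at depth d receiving reward r,
    #   value_lower = parent.value_lower + gamma**(d - 1) * r
    #   value_upper = value_lower + gamma**d / (1 - gamma)
    # e.g. with gamma = 0.9, a depth-1 child with reward 1 gets
    #   value_lower = 0 + 0.9**0 * 1 = 1 and value_upper = 1 + 0.9 / 0.1 = 10.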
    def update(self, reward, done, observation=None):
        if not np.all(0 <= reward) or not np.all(reward <= 1):
            raise ValueError("This planner assumes that all rewards are normalized in [0, 1]")
        gamma = self.planner.config["gamma"]
        self.reward = reward
        self.observation = observation
        self.done = done
        self.value_lower = self.parent.value_lower + (gamma ** (self.depth - 1)) * reward
        self.value_upper = self.value_lower + (gamma ** self.depth) / (1 - gamma)
        if isinstance(done, np.ndarray):
            idx = np.where(done)
            next_value = self.value_lower[idx] + \
                self.planner.config["terminal_reward"] * (gamma ** self.depth) / (1 - gamma)
            self.value_lower[idx] = next_value
            self.value_upper[idx] = next_value
        elif done:
            self.value_lower = self.value_upper = self.value_lower + \
                self.planner.config["terminal_reward"] * (gamma ** self.depth) / (1 - gamma)
        for node in self.sequence():
            node.count += 1

    def backup_values(self):
        if self.children:
            backup_children = [child.backup_values() for child in self.children.values()]
            self.value_lower = np.amax([b[0] for b in backup_children])
            self.value_upper = np.amax([b[1] for b in backup_children])
        return self.get_value_lower_bound(), self.get_value_upper_bound()

    def backup_to_root(self):
        if self.children:
            self.value_lower = np.amax([child.get_value_lower_bound() for child in self.children.values()])
            self.value_upper = np.amax([child.get_value_upper_bound() for child in self.children.values()])
        if self.parent:
            self.parent.backup_to_root()

    def get_value_lower_bound(self):
        return self.value_lower

    def get_value_upper_bound(self):
        return self.value_upper

    def get_value(self) -> float:
        return self.value_upper


class OptimisticDeterministicPlanner(AbstractPlanner):
    """
    An implementation of Optimistic Planning in Deterministic MDPs.
    """
    NODE_TYPE = DeterministicNode

    def __init__(self, env, config=None):
        super(OptimisticDeterministicPlanner, self).__init__(config)
        self.env = env
        self.leaves = None

    def reset(self):
        self.root = self.NODE_TYPE(None, planner=self)
        self.leaves = [self.root]

    def run(self):
        """
        Run one expansion iteration of the OptimisticDeterministicPlanner:
        expand the most optimistic leaf and back up its value bounds.
        """
        leaf_to_expand = max(self.leaves, key=lambda n: n.get_value_upper_bound())
        if leaf_to_expand.done:
            logger.warning("Expanding a terminal state")
        leaf_to_expand.expand()
        leaf_to_expand.backup_to_root()
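
    # plan() converts the budget (a number of simulated transitions) into a number
    # of expansions: each call to run() simulates one transition per available
    # action, hence budget // action_space.n expansions (e.g. 20 expansions for a
    # budget of 100 with 5 actions).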
    def plan(self, state, observation):
        self.root.state = state
        for epoch in np.arange(self.config["budget"] // state.action_space.n):
            logger.debug("Expansion {}/{}".format(epoch + 1, self.config["budget"] // state.action_space.n))
            self.run()
        return self.get_plan()

    def step_by_subtree(self, action):
        super(OptimisticDeterministicPlanner, self).step_by_subtree(action)
        if not self.root.children:
            self.leaves = [self.root]
        # v0 = r0 + g r1 + g^2 r2 + ... and v1 = r1 + g r2 + ... = (v0 - r0) / g
        for leaf in self.leaves:
            leaf.value_lower = (leaf.value_lower - self.root.reward) / self.config["gamma"]
            leaf.value_upper = (leaf.value_upper - self.root.reward) / self.config["gamma"]
        self.root.backup_values()


class DeterministicPlannerAgent(AbstractTreeSearchAgent):
    """
    An agent that performs optimistic planning in deterministic MDPs.
    """
    PLANNER_TYPE = OptimisticDeterministicPlanner
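

# Usage sketch (illustrative, not part of the original file): assuming a
# deterministic gymnasium-style environment with a discrete action space and
# rewards normalized to [0, 1], the agent would typically be driven as below.
# The environment id and config values are placeholders, not tested settings.
#
#     import gymnasium as gym
#     env = gym.make("SomeDeterministicEnv-v0")  # placeholder id, assumption
#     agent = DeterministicPlannerAgent(env, config={"budget": 100, "gamma": 0.9})
#     observation, info = env.reset()
#     action = agent.act(observation)  # plans from the current state, returns the first planned action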