-
Notifications
You must be signed in to change notification settings - Fork 0
/
robust.py
97 lines (70 loc) · 3.02 KB
/
robust.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
from rl_agents.agents.common.abstract import AbstractAgent
from rl_agents.agents.common.factory import load_agent, preprocess_env
from rl_agents.agents.tree_search.deterministic import DeterministicPlannerAgent, OptimisticDeterministicPlanner, \
DeterministicNode
class JointEnv:
    """Step several environments in lockstep, as if they were a single environment.

    Every call to :meth:`step` applies the same action to each wrapped environment and
    returns the per-environment transitions side by side.
    """

    def __init__(self, envs):
        """
        :param envs: iterable of environments exposing ``step(action)`` and ``action_space``.
            It is materialized into a list, so generators are accepted.
        :raises ValueError: if ``envs`` is empty, since the joint action space is read from
            the first environment and ``step`` needs at least one transition to unpack.
        """
        self.joint_state = list(envs)
        if not self.joint_state:
            raise ValueError("JointEnv requires at least one environment")

    def step(self, action):
        """Apply ``action`` to every environment.

        :return: ``(observations, rewards, terminals, info)`` where ``observations`` and
            ``info`` are tuples with one entry per environment, and ``rewards`` and
            ``terminals`` are numpy arrays with one entry per environment.
        """
        transitions = [env.step(action) for env in self.joint_state]
        # Transpose the list of per-env 4-tuples into four per-field tuples.
        observations, rewards, terminals, info = zip(*transitions)
        return observations, np.array(rewards), np.array(terminals), info

    @property
    def action_space(self):
        """Action space of the first environment (all environments are assumed to share it)."""
        return self.joint_state[0].action_space

    def get_available_actions(self):
        """Return the sorted union of the actions available in any of the environments.

        Environments without a ``get_available_actions`` method contribute every action
        of their discrete ``action_space`` (``range(action_space.n)``). Sorting keeps the
        returned order independent of set iteration order.
        NOTE(review): assumes actions are mutually orderable (e.g. ints) — confirm.
        """
        available = set()
        for env in self.joint_state:
            available.update(self._actions_of(env))
        return sorted(available)

    @staticmethod
    def _actions_of(env):
        """Actions available in a single environment."""
        if hasattr(env, "get_available_actions"):
            return env.get_available_actions()
        return range(env.action_space.n)
class DiscreteRobustPlanner(OptimisticDeterministicPlanner):
    """Optimistic deterministic planner whose search tree is built from ``RobustNode`` objects."""

    def reset(self):
        """Start a fresh tree: a single ``RobustNode`` root, which is also the only leaf."""
        fresh_root = RobustNode(parent=None, planner=self)
        self.root = fresh_root
        self.leaves = [fresh_root]
class RobustNode(DeterministicNode):
    """Search-tree node whose value bounds are reduced to their smallest entry.

    ``value_lower`` and ``value_upper`` are passed through ``np.min``. When they hold one
    entry per environment model (presumably the case when the planner steps a joint
    environment whose rewards are arrays — verify against ``DeterministicNode``), this
    yields the pessimistic bound over models. A scalar bound yields its own value.
    """

    @staticmethod
    def _worst_case(values):
        """Smallest entry of ``values``."""
        return np.min(values)

    def get_value_lower_bound(self):
        """Worst-case lower bound on this node's value."""
        return self._worst_case(self.value_lower)

    def get_value_upper_bound(self):
        """Worst-case upper bound on this node's value."""
        return self._worst_case(self.value_upper)
class DiscreteRobustPlannerAgent(DeterministicPlannerAgent):
    """Deterministic planner agent that plans against several environment variants at once.

    Each entry of ``config["models"]`` is a preprocessor configuration passed to
    ``preprocess_env`` together with the true environment. The resulting environments are
    combined in a ``JointEnv`` right before every planning call.
    """

    # Planner class used by this agent (presumably instantiated by the base agent — confirm).
    PLANNER_TYPE = DiscreteRobustPlanner

    def __init__(self, env, config=None):
        # Keep a handle on the real environment: ``plan`` later overwrites ``self.env``
        # with a JointEnv built from it.
        self.true_env = env
        super().__init__(env, config)

    @classmethod
    def default_config(cls):
        """Base configuration extended with an empty list of ``models``."""
        cfg = super().default_config()
        cfg.update(models=[])
        return cfg

    def plan(self, observation):
        """Rebuild the joint environment from the true one, then plan with the base agent."""
        model_envs = [
            preprocess_env(self.true_env, model_preprocessors)
            for model_preprocessors in self.config["models"]
        ]
        self.env = JointEnv(model_envs)
        return super().plan(observation)
class IntervalRobustPlannerAgent(AbstractAgent):
    """Agent that delegates to a loaded sub-agent planning on a preprocessed environment.

    The sub-agent is loaded from ``config["sub_agent_path"]``. Before every planning call,
    its ``env`` is replaced by ``self.env`` passed through ``config["env_preprocessors"]``.
    Every other agent operation is forwarded to the sub-agent unchanged.
    """

    def __init__(self, env, config=None):
        super().__init__(config)
        self.env = env
        self.sub_agent = load_agent(self.config['sub_agent_path'], env)

    @classmethod
    def default_config(cls):
        """Empty sub-agent path and no environment preprocessors."""
        return {"sub_agent_path": "", "env_preprocessors": []}

    def act(self, observation):
        """Return the first action of the plan computed for ``observation``."""
        actions = self.plan(observation)
        return actions[0]

    def plan(self, observation):
        """Refresh the sub-agent's environment, then let the sub-agent plan."""
        preprocessed_env = preprocess_env(self.env, self.config["env_preprocessors"])
        self.sub_agent.env = preprocessed_env
        return self.sub_agent.plan(observation)

    def get_plan(self):
        """Current plan held by the sub-agent's planner."""
        return self.sub_agent.planner.get_plan()

    # --- Plain delegation to the sub-agent -------------------------------------------

    def reset(self):
        return self.sub_agent.reset()

    def seed(self, seed=None):
        return self.sub_agent.seed(seed)

    def save(self, filename):
        return self.sub_agent.save(filename)

    def load(self, filename):
        return self.sub_agent.load(filename)

    def record(self, state, action, reward, next_state, done, info):
        return self.sub_agent.record(state, action, reward, next_state, done, info)