-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
203 lines (172 loc) · 6.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import pickle
from collections import OrderedDict
from gym.wrappers.time_limit import TimeLimit
import numpy as np
import gym
import dm2gym
def create_log_dir(experiment_name):
    """
    Create (if needed) a directory named `experiment_name` under the current
    working directory and return its path.

    Creation is best-effort: an OSError is reported on stdout rather than
    raised, and the path is returned either way.
    """
    path = os.path.join(os.getcwd(), experiment_name)
    created = True
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        created = False

    if created:
        print("Successfully created the directory %s " % path)
    else:
        print("Creation of the directory %s failed" % path)
    return path
def remove_prefix(text, prefix):
    """
    Return `text` without a leading `prefix`; `text` is returned unchanged
    when it does not start with `prefix`.
    """
    has_prefix = text.startswith(prefix)
    return text[len(prefix):] if has_prefix else text
def boolify(s):
    """
    Convert the exact strings 'True' / 'False' to the matching bool.

    Raises ValueError for anything else (including other capitalisations).
    """
    for literal, value in (('True', True), ('False', False)):
        if s == literal:
            return value
    raise ValueError("String '{}' is not a known bool value.".format(s))
def autoconvert(s):
    """
    Convert `s` to the first of bool, int, float that accepts it.

    Converters are tried in that order; one that raises ValueError is
    skipped. If none of them accepts `s`, it is returned unchanged.
    """
    converters = (boolify, int, float)
    for convert in converters:
        try:
            value = convert(s)
        except ValueError:
            continue
        else:
            return value
    return s
def update_param(params, arg_name, arg_value):
    """
    Overwrite `params[arg_name]` in place with the auto-converted `arg_value`.

    Raises KeyError when `arg_name` is not already a key of `params`, and
    ValueError when the converted value's type is not exactly the type of
    the value currently stored.
    """
    # Guard clause: only existing parameters may be overridden.
    if arg_name not in params:
        raise KeyError(
            "Parameter '{}' specified, but not found in hyperparams file.".
            format(arg_name))

    print("Updating parameter '{}' to {}".format(arg_name, arg_value))
    converted_arg_value = autoconvert(arg_value)

    # Exact type comparison (not isinstance), so e.g. an int is not accepted
    # where a float or a bool is stored.
    types_match = type(params[arg_name]) == type(converted_arg_value)
    if not types_match:
        error_str = f"Old and new type must match! Got {type(converted_arg_value)}, expected {type(params[arg_name])}, for {arg_name}"
        raise ValueError(error_str)

    params[arg_name] = converted_arg_value
class DMSettlingWrapper(gym.Wrapper):
    """
    Allows for many "no-op" actions before the actual episode begins. Complicated by the fact that
    dm_control does internal "resets", and that we want to preserve the wrapping time-limit.

    After every reset, `random_steps` all-zero actions are applied directly to
    the DMSuiteEnv that sits underneath the TimeLimit wrapper. When
    `reset_on_reward` is true and the last settling step produced a non-zero
    reward, the whole reset is retried, at most `MAX_RESETS` times.
    """

    def __init__(self, env, random_steps=5000, reset_on_reward=False):
        """
        env: must be a gym TimeLimit wrapping a dm2gym DMSuiteEnv.
        random_steps: number of no-op steps applied after each reset.
        reset_on_reward: retry the reset when settling ends on a non-zero reward.
        """
        assert isinstance(env, gym.wrappers.TimeLimit)
        assert isinstance(env.env, dm2gym.envs.dm_suite_env.DMSuiteEnv)
        super().__init__(env)
        # Lift the step limit of the object wrapped by DMSuiteEnv (presumably
        # the dm_control environment -- TODO confirm) so that only the outer
        # TimeLimit wrapper bounds the episode length.
        print(f"old step limit: {env.env.env._step_limit}")
        env.env.env._step_limit = float("inf")
        print(f"new step limit: {env.env.env._step_limit}")
        print(f"what is self env? {self.env}")
        self._random_steps = random_steps
        self.reset_on_reward = reset_on_reward
        # Upper bound on how many times a single reset() call may retry.
        self.MAX_RESETS = 10

    @staticmethod
    def sanitize_kwargs(kwargs):
        """
        Return a copy of `kwargs` without the internal "num_resets" key; the
        input mapping is left untouched.
        """
        kwargs_copy = dict(kwargs)
        kwargs_copy.pop("num_resets", None)
        return kwargs_copy

    def unwrap_env(self):
        """
        Return the DMSuiteEnv found directly beneath the TimeLimit wrapper.
        """
        env = self.env
        assert isinstance(env, gym.wrappers.TimeLimit)
        env = env.env
        assert isinstance(env, dm2gym.envs.dm_suite_env.DMSuiteEnv)
        return env

    def _settle(self, state):
        """
        Apply `self._random_steps` no-op actions to the inner DMSuiteEnv.

        Returns the last observed (state, reward); with zero settling steps
        that is the `state` passed in and a reward of 0. Raises RuntimeError
        if the inner env reports `done` while settling.
        """
        env = self.unwrap_env()
        noop_action = np.zeros_like(np.array(env.action_space.sample()))
        reward = 0
        for _ in range(self._random_steps):
            # Step the inner env directly so the TimeLimit wrapper's step
            # counter is not advanced by the settling phase.
            state, reward, done, _info = env.step(noop_action)
            if done:
                print("bad!")
                raise RuntimeError("you did something wrong!")
        return state, reward

    def reset(self, *args, **kwargs):
        """
        Reset the wrapped env, run the settling steps and return the state.

        `num_resets` is an internal retry-counter keyword; it is still
        accepted for backward compatibility and is stripped before the
        remaining arguments are forwarded to the wrapped env's reset.
        """
        num_resets = kwargs.get("num_resets", 0)
        # num_resets breaks the super's reset
        reset_kwargs = DMSettlingWrapper.sanitize_kwargs(kwargs)

        while True:
            state = super().reset(*args, **reset_kwargs)
            state, reward = self._settle(state)
            # Honour the constructor flag: only retry when asked to.
            if reward == 0 or not self.reset_on_reward:
                break
            print("REWARD IS NOT ZERO ")
            if num_resets >= self.MAX_RESETS:
                print(
                    f"But continuing because {num_resets}/{self.MAX_RESETS} done already"
                )
                break
            num_resets += 1
            print(f"Resetting # {num_resets}")

        # Settling bypassed TimeLimit, so its counter must still be at zero.
        assert self.env._elapsed_steps == 0
        return state
class DMSuiteUnwrapper(gym.Wrapper):
    """
    Exposes only the 'observations' entry of the wrapped env's dict
    observations, and narrows the observation space to match, so the whole
    interface is consistent.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = self.observation_space['observations']

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return its 'observations' entry."""
        obs_dict = super().reset(*args, **kwargs)
        assert isinstance(obs_dict, OrderedDict)
        return obs_dict['observations']

    def step(self, *args, **kwargs):
        """Step the wrapped env, replacing the dict state by its 'observations' entry."""
        transition = super().step(*args, **kwargs)
        obs_dict, reward, done, info = transition
        assert isinstance(obs_dict, OrderedDict)
        return obs_dict['observations'], reward, done, info

    def render(self, *args, **kwargs):
        """Render via the wrapped env, always forcing use_opencv_renderer=True."""
        render_kwargs = {**kwargs, 'use_opencv_renderer': True}
        return super().render(*args, **render_kwargs)
class FrameSkipWrapper(gym.Wrapper):
    """
    Repeats each action for up to `skip` consecutive env steps and returns
    the summed reward, stopping early as soon as the env reports done.
    """

    def __init__(self, env, skip:int=1):
        print('frame skipping!')
        skip_is_valid = isinstance(skip, int) and skip >= 1
        assert skip_is_valid, f"Skip was {skip}"
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """
        Apply `action` up to `self._skip` times; return the last observation,
        the accumulated reward, and the last done flag and info.
        """
        reward_sum = 0.
        done = False
        for _repeat in range(self._skip):
            obs, step_reward, done, info = self.env.step(action)
            reward_sum += step_reward
            if done:
                # Do not step an env that has already finished.
                break
        return obs, reward_sum, done, info
def make_env(env_name, step_limit=None, delay_time=-1, action_skip=1, seed=0):
    """
    Build an environment, optionally with a custom time limit and frame skip.

    env_name: gets passed to gym.make
    step_limit: goes into the TimeLimit wrapper. Not supported for "dm2gym:"
        environments (raises ValueError).
    delay_time: number of no-op settling steps for "dm2gym:" environments;
        a negative value selects the per-environment default below.
    action_skip: when not 1, the env is wrapped in FrameSkipWrapper with this
        skip.
    seed: forwarded as task_kwargs={"random": seed} for "dm2gym:"
        environments only; it is not used for other environments.

    Raises ValueError for an unsupported "dm2gym:" environment name.
    """
    # Imported for its side effects only (presumably registers the project's
    # custom tasks with gym -- TODO confirm).
    import rbfdqn.tasks

    if env_name.startswith("dm2gym:"):
        if step_limit is not None:
            raise ValueError("For now, no new step limits with dm_control stuff")

        # Default number of settling steps per supported dm2gym environment.
        delay_time_dict = {
            'dm2gym:CartpoleSwingup_sparse-v0': 0,
            'dm2gym:PendulumSwingup-v0': 1000,
            'dm2gym:Ball_in_cupCatch-v0': 1000,
            'dm2gym:AcrobotSwingup_sparse-v0': 1000,
            'dm2gym:HopperStand-v0': 1000,
        }
        if env_name not in delay_time_dict:
            # Derive the list from the table so the message cannot go stale.
            supported = ", ".join(sorted(delay_time_dict))
            raise ValueError(
                f"Currently only can handle these domains: {supported}; got {env_name}"
            )
        if delay_time < 0:
            delay_time = delay_time_dict[env_name]

        env = gym.make(env_name,
                       environment_kwargs={'flat_observation': True},
                       task_kwargs={"random": seed})
        env = DMSettlingWrapper(env,
                                random_steps=delay_time,
                                reset_on_reward=True)
        env = DMSuiteUnwrapper(env)
    else:
        env = gym.make(env_name)
        if step_limit is not None:
            print(f"Setting step-limit for {env_name} to {step_limit}")
            assert isinstance(env, gym.wrappers.TimeLimit), "for now assume it's time-limit wrapped"
            # Strip every existing wrapper, then re-apply a TimeLimit with
            # the requested horizon.
            env = env.unwrapped
            env = TimeLimit(env, max_episode_steps=step_limit)

    if action_skip != 1:
        env = FrameSkipWrapper(env, skip=action_skip)
    return env