-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathwrappers.py
135 lines (115 loc) · 4.49 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from gym.core import Wrapper
from gym.wrappers import Monitor as VideoMonitor
import time
import csv
import os.path as osp
import json
class ResultsWriter(object):
    """
    CSV log of per-episode statistics.

    The file starts with an optional header line, followed by a CSV header
    row with the columns 'r', 'l', 't' plus any `extra_keys`. One data row is
    appended (and flushed) per call to `write_row`.
    """

    def __init__(self, filename, header='', extra_keys=()):
        """
        Open the results file for writing and emit its header rows.

        filename: target path. If it does not already end with Monitor.EXT,
            an existing directory gets Monitor.EXT joined on as a file name;
            any other path gets "." + Monitor.EXT appended as a suffix.
        header: text written verbatim at the top of the file. A dict is
            serialised as a single '# <json> ' line instead.
        extra_keys: additional CSV column names placed after 'r', 'l', 't'.
        """
        self.extra_keys = extra_keys
        assert filename is not None, "ResultsWriter requires a filename"
        if not filename.endswith(Monitor.EXT):
            if osp.isdir(filename):
                filename = osp.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT

        # Serialise the header before touching the file system so that a
        # non-JSON-serialisable header does not leave an open, truncated file.
        if isinstance(header, dict):
            header = '# {} \n'.format(json.dumps(header))

        # The csv module expects its file to be opened with newline='';
        # without it the writer's '\r\n' row terminator is translated a second
        # time on platforms whose line separator is not '\n', which yields
        # blank rows between records.
        self.f = open(filename, "wt", newline='')
        self.f.write(header)
        self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't') + tuple(extra_keys))
        self.logger.writeheader()
        self.f.flush()

    def write_row(self, epinfo):
        """Append one episode record (a dict keyed by column name) and flush it to disk."""
        if self.logger:
            self.logger.writerow(epinfo)
            self.f.flush()

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        self.f.close()
class Reset(Wrapper):
    """
    gym wrapper that resets the wrapped environment as soon as an episode ends.

    When the inner step reports done, the observation handed back is the one
    produced by the reset, while reward, done and info are still those of the
    step that finished the episode.
    """

    def __init__(self, env):
        Wrapper.__init__(self, env=env)

    def step(self, action):
        """Step the wrapped env; on done, swap the observation for a fresh reset observation."""
        observation, reward, finished, extras = self.env.step(action)
        if not finished:
            return (observation, reward, finished, extras)
        # Episode is over: start the next one immediately and return its first observation.
        return (self.env.reset(), reward, finished, extras)
class Monitor(Wrapper):
    """
    gym wrapper that monitors the agent performance for every episode.

    episode_rewards: total reward obtained in each finished episode
    episode_lengths: number of steps of each finished episode
    episode_times: seconds elapsed since construction at the end of each episode
    total_steps: number of steps taken through this wrapper
    """
    EXT = "monitor.csv"
    # Never assigned by this class; kept so existing code that reads or sets
    # `f` keeps working. close() still closes it when something set it.
    f = None

    def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
        """
        env: the environment to wrap.
        filename: where to log episode results; a falsy value disables file logging.
        allow_early_resets: permit reset() before the current episode is done.
        reset_keywords: kwargs that must be passed to reset(); they are logged with each episode.
        info_keywords: keys copied from the final step's info dict into each episode record.
        """
        Wrapper.__init__(self, env=env)
        self.tstart = time.time()
        if filename:
            self.results_writer = ResultsWriter(
                filename,
                # Reuse tstart so the logged start time matches the reference
                # point the per-episode 't' column is measured against.
                header={"t_start": self.tstart, 'env_id': env.spec and env.spec.id},
                # Normalise to tuples so mixing a list with a tuple does not
                # raise TypeError on concatenation.
                extra_keys=tuple(reset_keywords) + tuple(info_keywords),
            )
        else:
            self.results_writer = None
        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        self.rewards = None
        self.needs_reset = True
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_times = []
        self.total_steps = 0
        self.current_reset_info = {}  # extra info about the current episode, that was passed in during reset()

    def reset(self, **kwargs):
        """
        Reset the wrapped env, recording the required reset keywords.

        Raises RuntimeError (via reset_state) on a disallowed early reset and
        ValueError when a keyword listed in reset_keywords is missing or None.
        """
        self.reset_state()
        for k in self.reset_keywords:
            v = kwargs.get(k)
            if v is None:
                raise ValueError('Expected you to pass kwarg %s into reset'%k)
            self.current_reset_info[k] = v
        return self.env.reset(**kwargs)

    def reset_state(self):
        """Start a new reward buffer; refuse early resets unless they are allowed."""
        if not self.allow_early_resets and not self.needs_reset:
            raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)")
        self.rewards = []
        self.needs_reset = False

    def step(self, action):
        """Step the wrapped env and record the transition; raises RuntimeError if a reset is pending."""
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        ob, rew, done, info = self.env.step(action)
        self.update(ob, rew, done, info)
        return (ob, rew, done, info)

    def update(self, ob, rew, done, info):
        """
        Account for one transition.

        On done, the episode summary is appended to the in-memory lists,
        written to the results file (if any) and stored under info['episode'].
        """
        self.rewards.append(rew)
        if done:
            self.needs_reset = True
            ep_rew = sum(self.rewards)
            ep_len = len(self.rewards)
            # Sample the clock once so the logged 't' and episode_times agree.
            elapsed = time.time() - self.tstart
            ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(elapsed, 6)}
            for k in self.info_keywords:
                ep_info[k] = info[k]
            self.episode_rewards.append(ep_rew)
            self.episode_lengths.append(ep_len)
            self.episode_times.append(elapsed)
            ep_info.update(self.current_reset_info)
            if self.results_writer:
                self.results_writer.write_row(ep_info)
            assert isinstance(info, dict)
            # The isinstance guard still applies when asserts are stripped (python -O).
            if isinstance(info, dict):
                info['episode'] = ep_info
        self.total_steps += 1

    def close(self):
        """Close the results file (if one was opened) and the wrapped environment."""
        # The log file lives on the results writer, not on `self.f`, so it has
        # to be closed through the writer or the handle is never released.
        if self.results_writer is not None:
            self.results_writer.f.close()
        if self.f is not None:
            self.f.close()
        # This method overrides the wrapper's close, so the wrapped env must
        # be closed explicitly here.
        self.env.close()

    def get_total_steps(self):
        """Return the number of steps taken through this wrapper."""
        return self.total_steps

    def get_episode_rewards(self):
        """Return the list of total rewards, one per finished episode."""
        return self.episode_rewards

    def get_episode_lengths(self):
        """Return the list of episode lengths, one per finished episode."""
        return self.episode_lengths

    def get_episode_times(self):
        """Return the list of elapsed times (seconds since construction), one per finished episode."""
        return self.episode_times