Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing seed / option args in dummy and subproc vec_env resets during stepping #1805

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).

- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)``.
Seed and options parameters will be passed to the next call to ``obs = vec_env.reset()`` and any implicit environment reset invoked by episode termination / truncation.
The provided seed and options will be discarded after each call to ``vec_env.reset()``.
qgallouedec marked this conversation as resolved.
Show resolved Hide resolved

- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
Expand Down
33 changes: 33 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,39 @@
Changelog
==========


Release 2.4.0a1 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- Fixed seed / options argument passing to environment resets in ``vec_env.reset()``

`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Expanded the description for vec_env.reset seed and options passing

Release 2.3.0 (2024-03-31)
--------------------------

Expand Down
5 changes: 4 additions & 1 deletion stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def step_wait(self) -> VecEnvStepReturn:
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]["terminal_observation"] = obs
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
# reset the environment, supplying seed and options
seed = self._seeds[env_idx]
options = self._options[env_idx]
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=seed, options=options)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as I wrote in my first comment: "PS: options should not be required to work with SB3, so you should in theory make it opt-in (see what I did for the reset())"

self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))

Expand Down
9 changes: 5 additions & 4 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ def _worker(
try:
cmd, data = remote.recv()
if cmd == "step":
observation, reward, terminated, truncated, info = env.step(data)
action, seed, options = data
observation, reward, terminated, truncated, info = env.step(action)
# convert to SB3 VecEnv api
done = terminated or truncated
info["TimeLimit.truncated"] = truncated and not terminated
if done:
# save final observation where user can get it, then reset
info["terminal_observation"] = observation
observation, reset_info = env.reset()
observation, reset_info = env.reset(seed=seed, options=options)
remote.send((observation, reward, done, info, reset_info))
elif cmd == "reset":
maybe_options = {"options": data[1]} if data[1] else {}
Expand Down Expand Up @@ -121,8 +122,8 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[
super().__init__(len(env_fns), observation_space, action_space)

def step_async(self, actions: np.ndarray) -> None:
for remote, action in zip(self.remotes, actions):
remote.send(("step", action))
for remote, action, seed, option in zip(self.remotes, actions, self._seeds, self._options):
remote.send(("step", (action, seed, option)))
qgallouedec marked this conversation as resolved.
Show resolved Hide resolved
self.waiting = True

def step_wait(self) -> VecEnvStepReturn:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def __init__(self, delay: float = 0.01):
self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
self.action_space = spaces.Discrete(2)

def reset(self, seed=None):
def reset(self, seed=None, options=None):
return self.observation_space.sample(), {}

def step(self, action):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self):
self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)

def reset(self, seed=None):
def reset(self, seed=None, options=None):
return self.observation_space.sample(), {}

def step(self, action):
Expand Down
Loading