From 8a46c3ab39f424519031a3103b7c2b4f97155f33 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Thu, 28 Nov 2024 12:18:55 +0000 Subject: [PATCH] Add support for different vector autoreset modes (#1227) --- gymnasium/envs/classic_control/cartpole.py | 3 +- gymnasium/envs/functional_jax_env.py | 3 +- gymnasium/envs/phys2d/cartpole.py | 8 +- gymnasium/envs/phys2d/pendulum.py | 8 +- gymnasium/envs/registration.py | 11 + gymnasium/envs/tabular/blackjack.py | 2 + gymnasium/envs/tabular/cliffwalking.py | 2 + gymnasium/vector/__init__.py | 2 + gymnasium/vector/async_vector_env.py | 87 ++++- gymnasium/vector/sync_vector_env.py | 120 ++++-- gymnasium/vector/vector_env.py | 38 +- gymnasium/wrappers/vector/common.py | 60 ++- .../wrappers/vector/stateful_observation.py | 30 +- .../wrappers/vector/vectorize_observation.py | 31 +- pyproject.toml | 2 +- tests/envs/registration/test_make_vec.py | 44 ++- tests/testing_env.py | 1 - tests/vector/test_autoreset_mode.py | 343 ++++++++++++++++++ tests/vector/test_vector_env.py | 111 +++++- tests/vector/test_vector_env_info.py | 2 +- tests/wrappers/vector/test_vector_wrappers.py | 9 +- 21 files changed, 854 insertions(+), 63 deletions(-) create mode 100644 tests/vector/test_autoreset_mode.py diff --git a/gymnasium/envs/classic_control/cartpole.py b/gymnasium/envs/classic_control/cartpole.py index c3e3e7781..9bd08a015 100644 --- a/gymnasium/envs/classic_control/cartpole.py +++ b/gymnasium/envs/classic_control/cartpole.py @@ -13,7 +13,7 @@ from gymnasium import logger, spaces from gymnasium.envs.classic_control import utils from gymnasium.error import DependencyNotInstalled -from gymnasium.vector import VectorEnv +from gymnasium.vector import AutoresetMode, VectorEnv from gymnasium.vector.utils import batch_space @@ -355,6 +355,7 @@ class CartPoleVectorEnv(VectorEnv): metadata = { "render_modes": ["rgb_array"], "render_fps": 50, + "autoreset_mode": AutoresetMode.NEXT_STEP, } def __init__( diff --git a/gymnasium/envs/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py index 1db7a0281..81b1549a5 100644 --- a/gymnasium/envs/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -12,6 +12,7 @@ from gymnasium.envs.registration import EnvSpec from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import seeding +from gymnasium.vector import AutoresetMode from gymnasium.vector.utils import batch_space @@ -115,7 +116,7 @@ def __init__( """Initialize the environment from a FuncEnv.""" super().__init__() if metadata is None: - metadata = {} + metadata = {"autoreset_mode": AutoresetMode.NEXT_STEP} self.func_env = func_env self.num_envs = num_envs diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 99c3b5f44..fa4cd1dbc 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -15,6 +15,7 @@ from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle +from gymnasium.vector import AutoresetMode RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821 @@ -272,7 +273,12 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any): class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True} + metadata = { + "render_modes": ["rgb_array"], + "render_fps": 50, + "jax": True, + "autoreset_mode": AutoresetMode.NEXT_STEP, + } def __init__( self, diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 2e2538263..160049478 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -16,6 +16,7 @@ from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle +from gymnasium.vector import AutoresetMode RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821 @@ -225,7 +226,12 @@ def get_default_params(self, **kwargs) -> PendulumParams: class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based pendulum environment using the functional version as base.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True} + metadata = { + "render_modes": ["rgb_array"], + "render_fps": 30, + "jax": True, + "autoreset_mode": AutoresetMode.NEXT_STEP, + } def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor where the kwargs are passed to the base environment to modify the parameters.""" diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index de00ac1c4..1602d2fdd 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -19,6 +19,8 @@ import gymnasium as gym from gymnasium import Env, Wrapper, error, logger +from gymnasium.logger import warn +from gymnasium.vector import AutoresetMode if sys.version_info < (3, 10): @@ -976,6 +978,15 @@ def create_single_env() -> Env: copied_id_spec.kwargs["wrappers"] = wrappers env.unwrapped.spec = copied_id_spec + if "autoreset_mode" not in env.metadata: + warn( + f"The VectorEnv ({env}) is missing AutoresetMode metadata, metadata={env.metadata}" + ) + elif not isinstance(env.metadata["autoreset_mode"], AutoresetMode): + warn( + f"The VectorEnv ({env}) metadata['autoreset_mode'] is not an instance of AutoresetMode, {type(env.metadata['autoreset_mode'])}." + ) + return env diff --git a/gymnasium/envs/tabular/blackjack.py b/gymnasium/envs/tabular/blackjack.py index 714bfea07..8765bbd39 100644 --- a/gymnasium/envs/tabular/blackjack.py +++ b/gymnasium/envs/tabular/blackjack.py @@ -16,6 +16,7 @@ from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle, seeding +from gymnasium.vector import AutoresetMode from gymnasium.wrappers import HumanRendering @@ -239,6 +240,7 @@ class BlackjackFunctional( metadata = { "render_modes": ["rgb_array"], "render_fps": 4, + "autoreseet-mode": AutoresetMode.NEXT_STEP, } def transition( diff --git a/gymnasium/envs/tabular/cliffwalking.py b/gymnasium/envs/tabular/cliffwalking.py index 511dabe21..f8ec04c1b 100644 --- a/gymnasium/envs/tabular/cliffwalking.py +++ b/gymnasium/envs/tabular/cliffwalking.py @@ -15,6 +15,7 @@ from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle +from gymnasium.vector import AutoresetMode from gymnasium.wrappers import HumanRendering @@ -136,6 +137,7 @@ class CliffWalkingFunctional( metadata = { "render_modes": ["rgb_array"], "render_fps": 4, + "autoreset_mode": AutoresetMode.NEXT_STEP, } def transition( diff --git a/gymnasium/vector/__init__.py b/gymnasium/vector/__init__.py index 5e380b093..e55a098ad 100644 --- a/gymnasium/vector/__init__.py +++ b/gymnasium/vector/__init__.py @@ -4,6 +4,7 @@ from gymnasium.vector.async_vector_env import AsyncVectorEnv from gymnasium.vector.sync_vector_env import SyncVectorEnv from gymnasium.vector.vector_env import ( + AutoresetMode, VectorActionWrapper, VectorEnv, VectorObservationWrapper, @@ -21,4 +22,5 @@ "SyncVectorEnv", "AsyncVectorEnv", "utils", + "AutoresetMode", ] diff --git a/gymnasium/vector/async_vector_env.py b/gymnasium/vector/async_vector_env.py index 3d7e24b78..c681467ce 100644 --- a/gymnasium/vector/async_vector_env.py +++ b/gymnasium/vector/async_vector_env.py @@ -35,7 +35,7 @@ read_from_shared_memory, write_to_shared_memory, ) -from gymnasium.vector.vector_env import ArrayType, VectorEnv +from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv __all__ = ["AsyncVectorEnv", "AsyncState"] @@ -101,6 +101,7 @@ def __init__( | None ) = None, observation_mode: str | Space = "same", + autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP, ): """Vectorized environment that runs multiple environments in parallel. @@ -120,6 +121,7 @@ def __init__( 'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype, warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and ``observation_space``, warning, may raise unexpected errors. + autoreset_mode: The Autoreset Mode used, see todo for more details. Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance @@ -135,7 +137,15 @@ def __init__( self.env_fns = env_fns self.shared_memory = shared_memory self.copy = copy + self.context = context + self.daemon = daemon + self.worker = worker self.observation_mode = observation_mode + self.autoreset_mode = ( + autoreset_mode + if isinstance(autoreset_mode, AutoresetMode) + else AutoresetMode(autoreset_mode) + ) self.num_envs = len(env_fns) @@ -145,6 +155,7 @@ def __init__( # As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actual recreate the vector env. self.metadata = dummy_env.metadata + self.metadata["autoreset_mode"] = self.autoreset_mode self.render_mode = dummy_env.render_mode self.single_action_space = dummy_env.action_space @@ -211,6 +222,7 @@ def __init__( parent_pipe, _obs_buffer, self.error_queue, + self.autoreset_mode, ), ) @@ -287,9 +299,32 @@ def reset_async( str(self._state.value), ) - for pipe, env_seed in zip(self.parent_pipes, seed): - env_kwargs = {"seed": env_seed, "options": options} - pipe.send(("reset", env_kwargs)) + if options is not None and "reset_mask" in options: + reset_mask = options.pop("reset_mask") + assert isinstance( + reset_mask, np.ndarray + ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}" + assert reset_mask.shape == ( + self.num_envs, + ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}" + assert ( + reset_mask.dtype == np.bool_ + ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}" + assert np.any( + reset_mask + ), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}" + + for pipe, env_seed, env_reset in zip(self.parent_pipes, seed, reset_mask): + if env_reset: + env_kwargs = {"seed": env_seed, "options": options} + pipe.send(("reset", env_kwargs)) + else: + pipe.send(("reset-noop", None)) + else: + for pipe, env_seed in zip(self.parent_pipes, seed): + env_kwargs = {"seed": env_seed, "options": options} + pipe.send(("reset", env_kwargs)) + self._state = AsyncState.WAITING_RESET def reset_wait( @@ -688,11 +723,13 @@ def _async_worker( parent_pipe: Connection, shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...], error_queue: Queue, + autoreset_mode: AutoresetMode, ): env = env_fn() observation_space = env.observation_space action_space = env.action_space autoreset = False + observation = None parent_pipe.close() @@ -709,11 +746,23 @@ def _async_worker( observation = None autoreset = False pipe.send(((observation, info), True)) + elif command == "reset-noop": + pipe.send(((observation, {}), True)) elif command == "step": - if autoreset: - observation, info = env.reset() - reward, terminated, truncated = 0, False, False - else: + if autoreset_mode == AutoresetMode.NEXT_STEP: + if autoreset: + observation, info = env.reset() + reward, terminated, truncated = 0, False, False + else: + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + autoreset = terminated or truncated + elif autoreset_mode == AutoresetMode.SAME_STEP: ( observation, reward, @@ -721,7 +770,27 @@ def _async_worker( truncated, info, ) = env.step(data) - autoreset = terminated or truncated + + if terminated or truncated: + reset_observation, reset_info = env.reset() + + info = { + "final_info": info, + "final_obs": observation, + **reset_info, + } + observation = reset_observation + elif autoreset_mode == AutoresetMode.DISABLED: + assert autoreset is False + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + else: + raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}") if shared_memory: write_to_shared_memory( diff --git a/gymnasium/vector/sync_vector_env.py b/gymnasium/vector/sync_vector_env.py index b92a26889..7b98b2ce5 100644 --- a/gymnasium/vector/sync_vector_env.py +++ b/gymnasium/vector/sync_vector_env.py @@ -17,7 +17,7 @@ create_empty_array, iterate, ) -from gymnasium.vector.vector_env import ArrayType, VectorEnv +from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv __all__ = ["SyncVectorEnv"] @@ -65,6 +65,7 @@ def __init__( env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]], copy: bool = True, observation_mode: str | Space = "same", + autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP, ): """Vectorized environment that serially runs multiple environments. @@ -74,13 +75,22 @@ def __init__( observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces. 'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object allows the user to set some custom observation space mode not covered by 'same' or 'different.' + autoreset_mode: The Autoreset Mode used, see todo for more details. + Raises: RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment). """ - self.copy = copy + super().__init__() + self.env_fns = env_fns + self.copy = copy self.observation_mode = observation_mode + self.autoreset_mode = ( + autoreset_mode + if isinstance(autoreset_mode, AutoresetMode) + else AutoresetMode(autoreset_mode) + ) # Initialise all sub-environments self.envs = [env_fn() for env_fn in env_fns] @@ -89,6 +99,7 @@ def __init__( # As we support `make_vec(spec)` then we can't include a `spec = self.envs[0].spec` as this doesn't guarantee we can actual recreate the vector env. self.num_envs = len(self.envs) self.metadata = self.envs[0].metadata + self.metadata["autoreset_mode"] = self.autoreset_mode self.render_mode = self.envs[0].render_mode self.single_action_space = self.envs[0].action_space @@ -130,6 +141,7 @@ def __init__( ), f"Sub-environment action space doesn't make the `single_action_space`, action_space={env.action_space}, single_action_space={self.single_action_space}" # Initialise attributes used in `step` and `reset` + self._env_obs = [None for _ in range(self.num_envs)] self._observations = create_empty_array( self.single_observation_space, n=self.num_envs, fn=np.zeros ) @@ -175,23 +187,52 @@ def reset( len(seed) == self.num_envs ), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}." - self._terminations = np.zeros((self.num_envs,), dtype=np.bool_) - self._truncations = np.zeros((self.num_envs,), dtype=np.bool_) - - observations, infos = [], {} - for i, (env, single_seed) in enumerate(zip(self.envs, seed)): - env_obs, env_info = env.reset(seed=single_seed, options=options) + if options is not None and "reset_mask" in options: + reset_mask = options.pop("reset_mask") + assert isinstance( + reset_mask, np.ndarray + ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}" + assert reset_mask.shape == ( + self.num_envs, + ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}" + assert ( + reset_mask.dtype == np.bool_ + ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}" + assert np.any( + reset_mask + ), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}" + + self._terminations[reset_mask] = False + self._truncations[reset_mask] = False + self._autoreset_envs[reset_mask] = False + + infos = {} + for i, (env, single_seed, env_mask) in enumerate( + zip(self.envs, seed, reset_mask) + ): + if env_mask: + self._env_obs[i], env_info = env.reset( + seed=single_seed, options=options + ) + + infos = self._add_info(infos, env_info, i) + else: + self._terminations = np.zeros((self.num_envs,), dtype=np.bool_) + self._truncations = np.zeros((self.num_envs,), dtype=np.bool_) + self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_) + + infos = {} + for i, (env, single_seed) in enumerate(zip(self.envs, seed)): + self._env_obs[i], env_info = env.reset( + seed=single_seed, options=options + ) - observations.append(env_obs) - infos = self._add_info(infos, env_info, i) + infos = self._add_info(infos, env_info, i) # Concatenate the observations self._observations = concatenate( - self.single_observation_space, observations, self._observations + self.single_observation_space, self._env_obs, self._observations ) - - self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_) - return deepcopy(self._observations) if self.copy else self._observations, infos def step( @@ -204,29 +245,58 @@ def step( """ actions = iterate(self.action_space, actions) - observations, infos = [], {} + infos = {} for i, action in enumerate(actions): - if self._autoreset_envs[i]: - env_obs, env_info = self.envs[i].reset() - - self._rewards[i] = 0.0 - self._terminations[i] = False - self._truncations[i] = False - else: + if self.autoreset_mode == AutoresetMode.NEXT_STEP: + if self._autoreset_envs[i]: + self._env_obs[i], env_info = self.envs[i].reset() + + self._rewards[i] = 0.0 + self._terminations[i] = False + self._truncations[i] = False + else: + ( + self._env_obs[i], + self._rewards[i], + self._terminations[i], + self._truncations[i], + env_info, + ) = self.envs[i].step(action) + elif self.autoreset_mode == AutoresetMode.DISABLED: + # assumes that the user has correctly autoreset + assert not self._autoreset_envs[i], f"{self._autoreset_envs=}" ( - env_obs, + self._env_obs[i], self._rewards[i], self._terminations[i], self._truncations[i], env_info, ) = self.envs[i].step(action) + elif self.autoreset_mode == AutoresetMode.SAME_STEP: + ( + self._env_obs[i], + self._rewards[i], + self._terminations[i], + self._truncations[i], + env_info, + ) = self.envs[i].step(action) + + if self._terminations[i] or self._truncations[i]: + infos = self._add_info( + infos, + {"final_obs": self._env_obs[i], "final_info": env_info}, + i, + ) + + self._env_obs[i], env_info = self.envs[i].reset() + else: + raise ValueError(f"Unexpected autoreset mode, {self.autoreset_mode}") - observations.append(env_obs) infos = self._add_info(infos, env_info, i) # Concatenate the observations self._observations = concatenate( - self.single_observation_space, observations, self._observations + self.single_observation_space, self._env_obs, self._observations ) self._autoreset_envs = np.logical_or(self._terminations, self._truncations) diff --git a/gymnasium/vector/vector_env.py b/gymnasium/vector/vector_env.py index 7a59a11a2..00b6204c3 100644 --- a/gymnasium/vector/vector_env.py +++ b/gymnasium/vector/vector_env.py @@ -2,12 +2,14 @@ from __future__ import annotations +from enum import Enum from typing import TYPE_CHECKING, Any, Generic, TypeVar import numpy as np import gymnasium as gym from gymnasium.core import ActType, ObsType, RenderFrame +from gymnasium.logger import warn from gymnasium.utils import seeding @@ -24,9 +26,18 @@ "VectorActionWrapper", "VectorRewardWrapper", "ArrayType", + "AutoresetMode", ] +class AutoresetMode(Enum): + """Enum representing the different autoreset modes, next step, same step and disabled.""" + + NEXT_STEP: str = "NextStep" + SAME_STEP: str = "SameStep" + DISABLED: str = "Disabled" + + class VectorEnv(Generic[ObsType, ActType, ArrayType]): """Base class for vectorized environments to run multiple independent copies of the same environment in parallel. @@ -280,8 +291,15 @@ def _add_info( infos (dict): the (updated) infos of the vectorized environment """ for key, value in env_info.items(): + # It is easier for users to access their `final_obs` in the unbatched array of `obs` objects + if key == "final_obs": + if "final_obs" in vector_infos: + array = vector_infos["final_obs"] + else: + array = np.full(self.num_envs, fill_value=None, dtype=object) + array[env_num] = value # If value is a dictionary, then we apply the `_add_info` recursively. - if isinstance(value, dict): + elif isinstance(value, dict): array = self._add_info(vector_infos.get(key, {}), value, env_num) # Otherwise, we are a base case to group the data else: @@ -315,7 +333,6 @@ def _add_info( # Update the vector info with the updated data and mask information vector_infos[key], vector_infos[f"_{key}"] = array, array_mask - return vector_infos def __del__(self): @@ -509,6 +526,23 @@ class VectorObservationWrapper(VectorWrapper): Equivalent to :class:`gymnasium.ObservationWrapper` for vectorized environments. """ + def __init__(self, env: VectorEnv): + """Vector observation wrapper that batch transforms observations. + + Args: + env: Vector environment. + """ + super().__init__(env) + if "autoreset_mode" not in env.metadata: + warn( + f"Vector environment ({env}) is missing `autoreset_mode` metadata key." + ) + else: + assert ( + env.metadata["autoreset_mode"] == AutoresetMode.NEXT_STEP + or env.metadata["autoreset_mode"] == AutoresetMode.DISABLED + ) + def reset( self, *, diff --git a/gymnasium/wrappers/vector/common.py b/gymnasium/wrappers/vector/common.py index a9a570520..ea20186ea 100644 --- a/gymnasium/wrappers/vector/common.py +++ b/gymnasium/wrappers/vector/common.py @@ -8,7 +8,13 @@ import numpy as np from gymnasium.core import ActType, ObsType -from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper +from gymnasium.logger import warn +from gymnasium.vector.vector_env import ( + ArrayType, + AutoresetMode, + VectorEnv, + VectorWrapper, +) __all__ = ["RecordEpisodeStatistics"] @@ -78,13 +84,19 @@ def __init__( """ super().__init__(env) self._stats_key = stats_key + if "autoreset_mode" not in self.env.metadata: + warn("todo") + self._autoreset_mode = AutoresetMode.NEXT_STEP + else: + assert isinstance(self.env.metadata["autoreset_mode"], AutoresetMode) + self._autoreset_mode = self.env.metadata["autoreset_mode"] self.episode_count = 0 - self.episode_start_times: np.ndarray = np.zeros(()) - self.episode_returns: np.ndarray = np.zeros(()) - self.episode_lengths: np.ndarray = np.zeros((), dtype=int) - self.prev_dones: np.ndarray = np.zeros((), dtype=bool) + self.episode_start_times: np.ndarray = np.zeros((self.num_envs,)) + self.episode_returns: np.ndarray = np.zeros((self.num_envs,)) + self.episode_lengths: np.ndarray = np.zeros((self.num_envs,), dtype=int) + self.prev_dones: np.ndarray = np.zeros((self.num_envs,), dtype=bool) self.time_queue = deque(maxlen=buffer_length) self.return_queue = deque(maxlen=buffer_length) @@ -98,10 +110,30 @@ def reset( """Resets the environment using kwargs and resets the episode returns and lengths.""" obs, info = super().reset(seed=seed, options=options) - self.episode_start_times = np.full(self.num_envs, time.perf_counter()) - self.episode_returns = np.zeros(self.num_envs) - self.episode_lengths = np.zeros(self.num_envs, dtype=int) - self.prev_dones = np.zeros(self.num_envs, dtype=bool) + if options is not None and "reset_mask" in options: + reset_mask = options.pop("reset_mask") + assert isinstance( + reset_mask, np.ndarray + ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}" + assert reset_mask.shape == ( + self.num_envs, + ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}" + assert ( + reset_mask.dtype == np.bool_ + ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}" + assert np.any( + reset_mask + ), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}" + + self.episode_start_times[reset_mask] = time.perf_counter() + self.episode_returns[reset_mask] = 0 + self.episode_lengths[reset_mask] = 0 + self.prev_dones[reset_mask] = False + else: + self.episode_start_times = np.full(self.num_envs, time.perf_counter()) + self.episode_returns = np.zeros(self.num_envs) + self.episode_lengths = np.zeros(self.num_envs, dtype=int) + self.prev_dones = np.zeros(self.num_envs, dtype=bool) return obs, info @@ -122,18 +154,22 @@ def step( ), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order." self.episode_returns[self.prev_dones] = 0 + self.episode_returns[np.logical_not(self.prev_dones)] += rewards[ + np.logical_not(self.prev_dones) + ] + self.episode_lengths[self.prev_dones] = 0 - self.episode_start_times[self.prev_dones] = time.perf_counter() - self.episode_returns[~self.prev_dones] += rewards[~self.prev_dones] self.episode_lengths[~self.prev_dones] += 1 + self.episode_start_times[self.prev_dones] = time.perf_counter() + self.prev_dones = dones = np.logical_or(terminations, truncations) num_dones = np.sum(dones) if num_dones: if self._stats_key in infos or f"_{self._stats_key}" in infos: raise ValueError( - f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}" + f"Attempted to add episode stats with key '{self._stats_key}' but this key already exists in info: {list(infos.keys())}" ) else: episode_time_length = np.round( diff --git a/gymnasium/wrappers/vector/stateful_observation.py b/gymnasium/wrappers/vector/stateful_observation.py index 266c488d1..a220cb316 100644 --- a/gymnasium/wrappers/vector/stateful_observation.py +++ b/gymnasium/wrappers/vector/stateful_observation.py @@ -5,11 +5,18 @@ from __future__ import annotations +from typing import Any + import numpy as np import gymnasium as gym from gymnasium.core import ObsType -from gymnasium.vector.vector_env import VectorEnv, VectorObservationWrapper +from gymnasium.logger import warn +from gymnasium.vector.vector_env import ( + AutoresetMode, + VectorEnv, + VectorObservationWrapper, +) from gymnasium.wrappers.utils import RunningMeanStd @@ -65,6 +72,13 @@ def __init__(self, env: VectorEnv, epsilon: float = 1e-8): gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon) VectorObservationWrapper.__init__(self, env) + if "autoreset_mode" not in self.env.metadata: + warn( + f"{self} is missing `autoreset_mode` data. Assuming that the vector environment it follows the `NextStep` autoreset api or autoreset is disabled. Read todo for more details." + ) + else: + assert self.env.metadata["autoreset_mode"] in {AutoresetMode.NEXT_STEP} + self.obs_rms = RunningMeanStd( shape=self.single_observation_space.shape, dtype=self.single_observation_space.dtype, @@ -82,6 +96,20 @@ def update_running_mean(self, setting: bool): """Sets the property to freeze/continue the running mean calculation of the observation statistics.""" self._update_running_mean = setting + def reset( + self, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: + """Reset function for `NormalizeObservationWrapper` which is disabled for partial resets.""" + assert ( + options is None + or "reset_mask" not in options + or np.all(options["reset_mask"]) + ) + return super().reset(seed=seed, options=options) + def observations(self, observations: ObsType) -> ObsType: """Defines the vector observation normalization function. diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py index 88bd539ad..52b5b9a07 100644 --- a/gymnasium/wrappers/vector/vectorize_observation.py +++ b/gymnasium/wrappers/vector/vectorize_observation.py @@ -8,9 +8,11 @@ import numpy as np from gymnasium import Space -from gymnasium.core import Env, ObsType +from gymnasium.core import ActType, Env, ObsType +from gymnasium.logger import warn from gymnasium.vector import VectorEnv, VectorObservationWrapper from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate +from gymnasium.vector.vector_env import ArrayType, AutoresetMode from gymnasium.wrappers import transform_observation @@ -138,6 +140,15 @@ def __init__( """ super().__init__(env) + if "autoreset_mode" not in env.metadata: + warn( + f"Vector environment ({env}) is missing `autoreset_mode` metadata key." + ) + self.autoreset_mode = AutoresetMode.NEXT_STEP + else: + assert isinstance(env.metadata["autoreset_mode"], AutoresetMode) + self.autoreset_mode = env.metadata["autoreset_mode"] + self.wrapper = wrapper( self._SingleEnv(self.env.single_observation_space), **kwargs ) @@ -149,6 +160,24 @@ def __init__( self.same_out = self.observation_space == self.env.observation_space self.out = create_empty_array(self.single_observation_space, self.num_envs) + def step( + self, actions: ActType + ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: + """Steps through the vector environments, transforming the observation and for final obs individually transformed.""" + obs, rewards, terminations, truncations, infos = self.env.step(actions) + obs = self.observations(obs) + + if self.autoreset_mode == AutoresetMode.SAME_STEP and "final_obs" in infos: + final_obs = infos["final_obs"] + + for i, (sub_obs, has_final_obs) in enumerate( + zip(final_obs, infos["_final_obs"]) + ): + if has_final_obs: + final_obs[i] = self.wrapper.observation(sub_obs) + + return obs, rewards, terminations, truncations, infos + def observations(self, observations: ObsType) -> ObsType: """Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again.""" if self.same_out: diff --git a/pyproject.toml b/pyproject.toml index 3942ae9a1..695403228 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ all = [ "moviepy >=1.0.0", ] testing = [ - "pytest ==7.1.3", + "pytest >=7.1.3", "scipy >=1.7.3", "dill >=0.3.7", ] diff --git a/tests/envs/registration/test_make_vec.py b/tests/envs/registration/test_make_vec.py index 32d1d743c..c2fbe09b3 100644 --- a/tests/envs/registration/test_make_vec.py +++ b/tests/envs/registration/test_make_vec.py @@ -2,6 +2,7 @@ import multiprocessing import re +import warnings import pytest @@ -9,7 +10,7 @@ from gymnasium import VectorizeMode, error, wrappers from gymnasium.envs.classic_control import CartPoleEnv from gymnasium.envs.classic_control.cartpole import CartPoleVectorEnv -from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv from gymnasium.wrappers import TimeLimit, TransformObservation from tests.wrappers.utils import has_wrapper @@ -282,3 +283,44 @@ def test_make_vec_with_spec_additional_wrappers(): del gym.registry["TestEnv-v0"] del gym.registry["TestEnv-v1"] + + +class MissingMetadataVecEnv(VectorEnv): + metadata = {"render_fps": 1} + + def __init__(self, num_envs: int): + self.num_envs = num_envs + + +class IncorrectMetadataVecEnv(VectorEnv): + metadata = {"autoreset_mode": "next_step"} + + def __init__(self, num_envs: int): + self.num_envs = num_envs + + +def test_missing_autoreset_mode_metadata(): + gym.register("MissingMetadataVecEnv-v0", vector_entry_point=MissingMetadataVecEnv) + gym.register( + "IncorrectMetadataVecEnv-v0", vector_entry_point=IncorrectMetadataVecEnv + ) + + with warnings.catch_warnings(): + with pytest.warns( + UserWarning, + match=re.escape( + "The VectorEnv (MissingMetadataVecEnv(MissingMetadataVecEnv-v0, num_envs=1)) is missing AutoresetMode metadata, metadata={'render_fps': 1}" + ), + ): + gym.make_vec("MissingMetadataVecEnv-v0") + + with pytest.warns( + UserWarning, + match=re.escape( + "The VectorEnv (IncorrectMetadataVecEnv(IncorrectMetadataVecEnv-v0, num_envs=1)) metadata['autoreset_mode'] is not an instance of AutoresetMode, ." + ), + ): + gym.make_vec("IncorrectMetadataVecEnv-v0") + + gym.registry.pop("MissingMetadataVecEnv-v0") + gym.registry.pop("IncorrectMetadataVecEnv-v0") diff --git a/tests/testing_env.py b/tests/testing_env.py index dafbf15a6..41572a62d 100644 --- a/tests/testing_env.py +++ b/tests/testing_env.py @@ -45,7 +45,6 @@ def basic_render_func(self): pass -# todo: change all testing environment to this generic class class GenericTestEnv(gym.Env): """A generic testing environment for use in testing with modified environments are required.""" diff --git a/tests/vector/test_autoreset_mode.py b/tests/vector/test_autoreset_mode.py new file mode 100644 index 000000000..a8673f65f --- /dev/null +++ b/tests/vector/test_autoreset_mode.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +from functools import partial + +import numpy as np +import pytest + +import gymnasium as gym +from gymnasium import VectorizeMode +from gymnasium.spaces import Discrete +from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv +from gymnasium.vector.vector_env import AutoresetMode +from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS +from tests.testing_env import GenericTestEnv + + +def count_reset( + self: GenericTestEnv, seed: int | None = None, options: dict | None = None +): + super(GenericTestEnv, self).reset(seed=seed) + + self.count = seed if seed is not None else 0 + return self.count, {} + + +def count_step(self: GenericTestEnv, action): + self.count += 1 + + return self.count, action, self.count == self.max_count, False, {} + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +def test_autoreset_next_step(vectoriser): + envs = vectoriser( + [ + lambda: GenericTestEnv( + action_space=Discrete(5), + observation_space=Discrete(5), + reset_func=count_reset, + step_func=count_step, + ) + for _ in range(3) + ], + autoreset_mode=AutoresetMode.NEXT_STEP, + ) + assert envs.metadata["autoreset_mode"] == AutoresetMode.NEXT_STEP + envs.set_attr("max_count", [2, 3, 3]) + + obs, info = envs.reset() + assert np.all(obs == [0, 0, 0]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [1, 1, 1]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [False, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [2, 2, 2]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [True, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [0, 3, 3]) + assert np.all(rewards == [0, 2, 3]) + assert np.all(terminations == [False, True, True]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [1, 0, 0]) + assert np.all(rewards == [1, 0, 0]) + assert np.all(terminations == [False, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + envs.close() + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +def test_autoreset_within_step(vectoriser): + envs = vectoriser( + [ + lambda: GenericTestEnv( + action_space=Discrete(5), + observation_space=Discrete(5), + reset_func=count_reset, + step_func=count_step, + ) + for _ in range(3) + ], + autoreset_mode=AutoresetMode.SAME_STEP, + ) + assert envs.metadata["autoreset_mode"] == AutoresetMode.SAME_STEP + envs.set_attr("max_count", [2, 3, 3]) + + obs, info = envs.reset() + assert np.all(obs == [0, 0, 0]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [1, 1, 1]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [False, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [0, 2, 2]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [True, False, False]) + assert np.all(truncations == [False, False, False]) + assert data_equivalence( + info, + { + "final_obs": np.array([2, None, None], dtype=object), + "final_info": {}, + "_final_obs": np.array([True, False, False]), + "_final_info": np.array([True, False, False]), + }, + ) + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [1, 0, 0]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [False, True, True]) + assert np.all(truncations == [False, False, False]) + assert data_equivalence( + info, + { + "final_obs": np.array([None, 3, 3], dtype=object), + "final_info": {}, + "_final_obs": np.array([False, True, True]), + "_final_info": np.array([False, True, True]), + }, + ) + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [0, 1, 1]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [True, False, False]) + assert np.all(truncations == [False, False, False]) + assert data_equivalence( + info, + { + "final_obs": np.array([2, None, None], dtype=object), + "final_info": {}, + "_final_obs": np.array([True, False, False]), + "_final_info": np.array([True, False, False]), + }, + ) + + envs.close() + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +def test_autoreset_disabled(vectoriser): + envs = vectoriser( + [ + lambda: GenericTestEnv( + action_space=Discrete(5), + observation_space=Discrete(5), + reset_func=count_reset, + step_func=count_step, + ) + for _ in range(3) + ], + autoreset_mode=AutoresetMode.DISABLED, + ) + assert envs.metadata["autoreset_mode"] == AutoresetMode.DISABLED + envs.set_attr("max_count", [2, 3, 3]) + + obs, info = envs.reset() + assert np.all(obs == [0, 0, 0]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [1, 1, 1]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [False, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [2, 2, 2]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [True, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, info = envs.reset(options={"reset_mask": terminations}) + assert np.all(obs == [0, 2, 2]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [1, 3, 3]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [False, True, True]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + obs, info = envs.reset(options={"reset_mask": terminations}) + assert np.all(obs == [1, 0, 0]) + assert info == {} + + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert np.all(obs == [2, 1, 1]) + assert np.all(rewards == [1, 2, 3]) + assert np.all(terminations == [True, False, False]) + assert np.all(truncations == [False, False, False]) + assert info == {} + + envs.close() + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +@pytest.mark.parametrize( + "autoreset_mode", + [AutoresetMode.NEXT_STEP, AutoresetMode.DISABLED, AutoresetMode.SAME_STEP], +) +def test_autoreset_metadata(vectoriser, autoreset_mode): + envs = vectoriser( + [lambda: GenericTestEnv(), lambda: GenericTestEnv()], + autoreset_mode=autoreset_mode, + ) + assert envs.metadata["autoreset_mode"] == autoreset_mode + envs.close() + + envs = vectoriser( + [lambda: GenericTestEnv(), lambda: GenericTestEnv()], + autoreset_mode=autoreset_mode.value, + ) + assert envs.metadata["autoreset_mode"] == autoreset_mode + envs.close() + + +@pytest.mark.parametrize( + "vectorization_mode", [VectorizeMode.SYNC, VectorizeMode.ASYNC] +) +@pytest.mark.parametrize( + "autoreset_mode", + [AutoresetMode.NEXT_STEP, AutoresetMode.DISABLED, AutoresetMode.SAME_STEP], +) +def test_make_vec_autoreset(vectorization_mode, autoreset_mode): + envs = gym.make_vec( + "CartPole-v1", + vectorization_mode=vectorization_mode, + vector_kwargs={"autoreset_mode": autoreset_mode}, + ) + envs.metadata["autoreset_mode"] = autoreset_mode + envs.close() + + envs = gym.make_vec( + "CartPole-v1", + vectorization_mode=vectorization_mode, + vector_kwargs={"autoreset_mode": autoreset_mode.value}, + ) + envs.metadata["autoreset_mode"] = autoreset_mode + envs.close() + + +def count_reset_obs( + self: GenericTestEnv, seed: int | None = None, options: dict | None = None +): + super(GenericTestEnv, self).reset(seed=seed) + + self.count = seed if seed is not None else 0 + return self.observation_space.sample(), {} + + +def count_step_obs(self: GenericTestEnv, action): + self.count += 1 + + return ( + self.observation_space.sample(), + action, + self.count == self.max_count, + False, + {}, + ) + + +@pytest.mark.parametrize("obs_space", TESTING_SPACES, ids=TESTING_SPACES_IDS) +def test_same_step_final_obs(obs_space): + envs = SyncVectorEnv( + [ + lambda: GenericTestEnv( + action_space=Discrete(5), + observation_space=obs_space, + reset_func=count_reset_obs, + step_func=count_step_obs, + ) + for _ in range(3) + ], + autoreset_mode=AutoresetMode.SAME_STEP, + ) + assert envs.metadata["autoreset_mode"] == AutoresetMode.SAME_STEP + envs.set_attr("max_count", [2, 3, 3]) + + envs.reset() + envs.step([1, 2, 3]) + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert info["final_obs"][0] in envs.single_observation_space + obs, rewards, terminations, truncations, info = envs.step([1, 2, 3]) + assert info["final_obs"][1] in envs.single_observation_space + assert info["final_obs"][2] in envs.single_observation_space diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 6286d9623..a88bfe5b0 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -2,28 +2,36 @@ from __future__ import annotations +import re from functools import partial import numpy as np import pytest +import gymnasium as gym from gymnasium.core import ActType, ObsType from gymnasium.spaces import Discrete from gymnasium.utils.env_checker import data_equivalence from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv +from gymnasium.vector.vector_env import AutoresetMode from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS from tests.testing_env import GenericTestEnv from tests.vector.testing_utils import make_env @pytest.mark.parametrize("shared_memory", [True, False]) -def test_vector_env_equal(shared_memory): +@pytest.mark.parametrize( + "autoreset_mode", [AutoresetMode.NEXT_STEP, AutoresetMode.SAME_STEP] +) +def test_vector_env_equal(shared_memory, autoreset_mode): """Test that vector environment are equal for both async and sync variants.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] num_steps = 100 - async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - sync_env = SyncVectorEnv(env_fns) + async_env = AsyncVectorEnv( + env_fns, shared_memory=shared_memory, autoreset_mode=autoreset_mode + ) + sync_env = SyncVectorEnv(env_fns, autoreset_mode=autoreset_mode) assert async_env.num_envs == sync_env.num_envs assert async_env.observation_space == sync_env.observation_space @@ -212,3 +220,100 @@ def test_random_seeds_set_at_retrieval(venv_constructor, example_env_list): assert len(set(vector_env.np_random_seed)) == vector_env.num_envs # default seed starts at zero. Adjust or remove this test if the default seed changes assert vector_env.np_random_seed == tuple(range(vector_env.num_envs)) + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +def test_partial_reset(vectoriser): + envs = vectoriser( + [lambda: gym.make("CartPole-v1") for _ in range(3)], + autoreset_mode=AutoresetMode.DISABLED, + ) + reset_obs, _ = envs.reset(seed=[0, 1, 2]) + + envs.action_space.seed(123) + envs.step(envs.action_space.sample()) + envs.step(envs.action_space.sample()) + step_obs, *_ = envs.step(envs.action_space.sample()) + + reset_mask_obs, _ = envs.reset( + seed=[0, 1, 0], options={"reset_mask": np.array([True, True, False])} + ) + assert np.all(reset_mask_obs[:2] == reset_obs[:2]) + assert np.all(reset_mask_obs[2] == step_obs[2]) + + envs.close() + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +def test_partial_reset_failure(vectoriser): + envs = vectoriser( + [lambda: gym.make("CartPole-v1") for _ in range(3)], + autoreset_mode=AutoresetMode.DISABLED, + ) + + # Test first reset using a mask + # with pytest.raises(AssertionError): + # envs.reset(options={"reset_mask": np.array([True, True, False])}) + + # Reset with all trues + envs.reset(options={"reset_mask": np.array([True, True, True])}) + + # Reset with mask of an incorrect shape + with pytest.raises( + AssertionError, + match=re.escape( + "`options['reset_mask': mask]` must have shape `(3,)`, got (1,)" + ), + ): + envs.reset(options={"reset_mask": np.array([True])}) + with pytest.raises( + AssertionError, + match=re.escape( + "options['reset_mask': mask]` must have shape `(3,)`, got (4,)" + ), + ): + envs.reset(options={"reset_mask": np.array([True, True, False, False])}) + with pytest.raises( + AssertionError, + match=re.escape( + "`options['reset_mask': mask]` must have shape `(3,)`, got (1, 3)" + ), + ): + envs.reset(options={"reset_mask": np.array([[True, True, True]])}) + with pytest.raises( + AssertionError, + match=re.escape( + "`options['reset_mask': mask]` must contain a boolean array, got reset_mask=[False False False]" + ), + ): + envs.reset(options={"reset_mask": np.array([False, False, False])}) + with pytest.raises( + AssertionError, + match=re.escape( + "`options['reset_mask': mask]` must have `dtype=np.bool_`, got int64" + ), + ): + envs.reset(options={"reset_mask": np.array([1, 1, 0])}) + with pytest.raises( + AssertionError, + match=re.escape( + "`options['reset_mask': mask]` must have `dtype=np.bool_`, got float64" + ), + ): + envs.reset(options={"reset_mask": np.array([1.0, 1.0, 0.0])}) diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py index e762758a6..fe1736a67 100644 --- a/tests/vector/test_vector_env_info.py +++ b/tests/vector/test_vector_env_info.py @@ -140,7 +140,7 @@ def step( @pytest.mark.parametrize("vectorizer", [AsyncVectorEnv, SyncVectorEnv]) -def test_vectorizers(vectorizer): +def test_vector_return_info(vectorizer): vec_env = vectorizer( [ lambda: ReturnInfoEnv([{"a": 1}, {"c": np.array([1, 2])}]), diff --git a/tests/wrappers/vector/test_vector_wrappers.py b/tests/wrappers/vector/test_vector_wrappers.py index 311ecc6fe..3855c13e6 100644 --- a/tests/wrappers/vector/test_vector_wrappers.py +++ b/tests/wrappers/vector/test_vector_wrappers.py @@ -19,6 +19,7 @@ from gymnasium.spaces import Box, Dict, Discrete from gymnasium.utils.env_checker import data_equivalence from gymnasium.vector import VectorEnv +from gymnasium.vector.vector_env import AutoresetMode from tests.testing_env import GenericTestEnv @@ -36,6 +37,9 @@ def custom_environments(): del gym.registry["DictObsEnv-v0"] +@pytest.mark.parametrize( + "autoreset_mode", [AutoresetMode.NEXT_STEP, AutoresetMode.SAME_STEP] +) @pytest.mark.parametrize("num_envs", (1, 3)) @pytest.mark.parametrize( "env_id, wrapper_name, kwargs", @@ -68,11 +72,12 @@ def custom_environments(): ), ) def test_vector_wrapper_equivalence( + autoreset_mode: AutoresetMode, + num_envs: int, env_id: str, wrapper_name: str, kwargs: dict[str, Any], - num_envs: int, - custom_environments, + custom_environments, # pytest fixture vectorization_mode: str = "sync", num_steps: int = 50, ):