Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different vector autoreset modes #1227

Merged
3 changes: 2 additions & 1 deletion gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled
from gymnasium.vector import VectorEnv
from gymnasium.vector import AutoresetMode, VectorEnv
from gymnasium.vector.utils import batch_space


Expand Down Expand Up @@ -355,6 +355,7 @@ class CartPoleVectorEnv(VectorEnv):
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(
Expand Down
3 changes: 2 additions & 1 deletion gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gymnasium.envs.registration import EnvSpec
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector import AutoresetMode
from gymnasium.vector.utils import batch_space


Expand Down Expand Up @@ -115,7 +116,7 @@ def __init__(
"""Initialize the environment from a FuncEnv."""
super().__init__()
if metadata is None:
metadata = {}
metadata = {"autoreset_mode": AutoresetMode.NEXT_STEP}
self.func_env = func_env
self.num_envs = num_envs

Expand Down
8 changes: 7 additions & 1 deletion gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode


RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
Expand Down Expand Up @@ -272,7 +273,12 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"jax": True,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(
self,
Expand Down
8 changes: 7 additions & 1 deletion gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode


RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
Expand Down Expand Up @@ -225,7 +226,12 @@ def get_default_params(self, **kwargs) -> PendulumParams:
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 30,
"jax": True,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
Expand Down
11 changes: 11 additions & 0 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import gymnasium as gym
from gymnasium import Env, Wrapper, error, logger
from gymnasium.logger import warn
from gymnasium.vector import AutoresetMode


if sys.version_info < (3, 10):
Expand Down Expand Up @@ -976,6 +978,15 @@ def create_single_env() -> Env:
copied_id_spec.kwargs["wrappers"] = wrappers
env.unwrapped.spec = copied_id_spec

if "autoreset_mode" not in env.metadata:
warn(
f"The VectorEnv ({env}) is missing AutoresetMode metadata, metadata={env.metadata}"
)
elif not isinstance(env.metadata["autoreset_mode"], AutoresetMode):
warn(
f"The VectorEnv ({env}) metadata['autoreset_mode'] is not an instance of AutoresetMode, {type(env.metadata['autoreset_mode'])}."
)

return env


Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle, seeding
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering


Expand Down Expand Up @@ -239,6 +240,7 @@ class BlackjackFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"autoreseet-mode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering


Expand Down Expand Up @@ -136,6 +137,7 @@ class CliffWalkingFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import (
AutoresetMode,
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
Expand All @@ -21,4 +22,5 @@
"SyncVectorEnv",
"AsyncVectorEnv",
"utils",
"AutoresetMode",
]
87 changes: 78 additions & 9 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv
from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv


__all__ = ["AsyncVectorEnv", "AsyncState"]
Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(
| None
) = None,
observation_mode: str | Space = "same",
autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP,
):
"""Vectorized environment that runs multiple environments in parallel.

Expand All @@ -120,6 +121,7 @@ def __init__(
'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and
``observation_space``, warning, may raise unexpected errors.
autoreset_mode: The Autoreset Mode used, see todo for more details.

Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
Expand All @@ -135,7 +137,15 @@ def __init__(
self.env_fns = env_fns
self.shared_memory = shared_memory
self.copy = copy
self.context = context
self.daemon = daemon
self.worker = worker
self.observation_mode = observation_mode
self.autoreset_mode = (
autoreset_mode
if isinstance(autoreset_mode, AutoresetMode)
else AutoresetMode(autoreset_mode)
)

self.num_envs = len(env_fns)

Expand All @@ -145,6 +155,7 @@ def __init__(

# As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actual recreate the vector env.
self.metadata = dummy_env.metadata
self.metadata["autoreset_mode"] = self.autoreset_mode
self.render_mode = dummy_env.render_mode

self.single_action_space = dummy_env.action_space
Expand Down Expand Up @@ -211,6 +222,7 @@ def __init__(
parent_pipe,
_obs_buffer,
self.error_queue,
self.autoreset_mode,
),
)

Expand Down Expand Up @@ -287,9 +299,32 @@ def reset_async(
str(self._state.value),
)

for pipe, env_seed in zip(self.parent_pipes, seed):
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))
if options is not None and "reset_mask" in options:
reset_mask = options.pop("reset_mask")
assert isinstance(
reset_mask, np.ndarray
), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
assert reset_mask.shape == (
self.num_envs,
), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
assert (
reset_mask.dtype == np.bool_
), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
assert np.any(
reset_mask
), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

for pipe, env_seed, env_reset in zip(self.parent_pipes, seed, reset_mask):
if env_reset:
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))
else:
pipe.send(("reset-noop", None))
else:
for pipe, env_seed in zip(self.parent_pipes, seed):
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))

self._state = AsyncState.WAITING_RESET

def reset_wait(
Expand Down Expand Up @@ -688,11 +723,13 @@ def _async_worker(
parent_pipe: Connection,
shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...],
error_queue: Queue,
autoreset_mode: AutoresetMode,
):
env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
autoreset = False
observation = None

parent_pipe.close()

Expand All @@ -709,19 +746,51 @@ def _async_worker(
observation = None
autoreset = False
pipe.send(((observation, info), True))
elif command == "reset-noop":
pipe.send(((observation, {}), True))
elif command == "step":
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
if autoreset_mode == AutoresetMode.NEXT_STEP:
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated
elif autoreset_mode == AutoresetMode.SAME_STEP:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated

if terminated or truncated:
reset_observation, reset_info = env.reset()

info = {
"final_info": info,
"final_obs": observation,
**reset_info,
}
observation = reset_observation
elif autoreset_mode == AutoresetMode.DISABLED:
assert autoreset is False
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
else:
raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")

if shared_memory:
write_to_shared_memory(
Expand Down
Loading
Loading