Skip to content

Commit

Permalink
Add (A)syncVectorEnv support for sub-envs with different obs spaces (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean authored Aug 9, 2024
1 parent 1a92702 commit d20ac56
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 19 deletions.
52 changes: 46 additions & 6 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import numpy as np

from gymnasium import logger
from gymnasium import Space, logger
from gymnasium.core import ActType, Env, ObsType, RenderFrame
from gymnasium.error import (
AlreadyPendingCallError,
Expand All @@ -33,6 +33,10 @@
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.utils.batched_spaces import (
all_spaces_have_same_shape,
batch_differing_spaces,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv


Expand Down Expand Up @@ -98,6 +102,7 @@ def __init__(
]
| None
) = None,
observation_mode: str | Space = "same",
):
"""Vectorized environment that runs multiple environments in parallel.
Expand All @@ -113,6 +118,9 @@ def __init__(
so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
Expand All @@ -139,12 +147,29 @@ def __init__(
self.metadata = dummy_env.metadata
self.render_mode = dummy_env.render_mode

self.single_observation_space = dummy_env.observation_space
self.single_action_space = dummy_env.action_space

self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
if isinstance(observation_mode, Space):
self.observation_space = observation_mode
else:
if observation_mode == "same":
self.single_observation_space = dummy_env.observation_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
elif observation_mode == "different":
current_spaces = [env().observation_space for env in self.env_fns]

assert all_spaces_have_same_shape(
current_spaces
), "Low & High values for observation spaces can be different but shapes need to be the same"

self.single_observation_space = batch_differing_spaces(current_spaces)

self.observation_space = self.single_observation_space

else:
raise ValueError("Need to pass in mode for batching observations")
self.action_space = batch_space(self.single_action_space, self.num_envs)

dummy_env.close()
Expand Down Expand Up @@ -716,7 +741,22 @@ def _async_worker(
elif command == "_check_spaces":
pipe.send(
(
(data[0] == observation_space, data[1] == action_space),
(
(data[0] == observation_space)
or (
hasattr(observation_space, "low")
and hasattr(observation_space, "high")
and np.any(
np.all(observation_space.low == data[0].low, axis=1)
)
and np.any(
np.all(
observation_space.high == data[0].high, axis=1
)
)
),
data[1] == action_space,
),
True,
)
)
Expand Down
74 changes: 62 additions & 12 deletions gymnasium/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@

import numpy as np

from gymnasium import Env
from gymnasium import Env, Space
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.vector.utils.batched_spaces import (
all_spaces_have_same_shape,
all_spaces_have_same_type,
batch_differing_spaces,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv


Expand Down Expand Up @@ -57,13 +62,16 @@ def __init__(
self,
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
copy: bool = True,
observation_mode: str | Space = "same",
):
"""Vectorized environment that serially runs multiple environments.
Args:
env_fns: iterable of callable functions that create the environments.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
Expand All @@ -80,15 +88,39 @@ def __init__(
self.metadata = self.envs[0].metadata
self.render_mode = self.envs[0].render_mode

# Initialises the single spaces from the sub-environments
self.single_observation_space = self.envs[0].observation_space
self.single_action_space = self.envs[0].action_space

# Initialise the obs and action space based on the desired mode

if isinstance(observation_mode, Space):
self.observation_space = observation_mode
else:
if observation_mode == "same":
self.single_observation_space = self.envs[0].observation_space
self.single_action_space = self.envs[0].action_space

self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
elif observation_mode == "different":
current_spaces = [env.observation_space for env in self.envs]

assert all_spaces_have_same_shape(
current_spaces
), "Low & High values for observation spaces can be different but shapes need to be the same"
assert all_spaces_have_same_type(
current_spaces
), "Observation spaces must have same Space type"

self.observation_space = batch_differing_spaces(current_spaces)

self.single_observation_space = self.observation_space

else:
raise ValueError("Need to pass in mode for batching observations")

self._check_spaces()

# Initialise the obs and action space based on the single versions and num of sub-environments
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)

# Initialise attributes used in `step` and `reset`
Expand Down Expand Up @@ -270,10 +302,28 @@ def _check_spaces(self) -> bool:
"""Check that each of the environments obs and action spaces are equivalent to the single obs and action space."""
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
raise RuntimeError(
f"Some environments have an observation space different from `{self.single_observation_space}`. "
"In order to batch observations, the observation spaces from all environments must be equal."
)
if not (
hasattr(env.observation_space, "low")
and hasattr(env.observation_space, "high")
and np.any(
np.all(
env.observation_space.low
== self.single_observation_space.low,
axis=1,
)
)
and np.any(
np.all(
env.observation_space.high
== self.single_observation_space.high,
axis=1,
)
)
):
raise RuntimeError(
f"Some environments have an observation space different from `{self.single_observation_space}`. "
"In order to batch observations, the observation spaces from all environments must be equal."
)

if not (env.action_space == self.single_action_space):
raise RuntimeError(
Expand Down
141 changes: 141 additions & 0 deletions gymnasium/vector/utils/batched_spaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Batching support for Spaces of same type but possibly varying low/high values."""

from __future__ import annotations

from copy import deepcopy
from functools import singledispatch

import numpy as np

from gymnasium import Space
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Text,
Tuple,
)


@singledispatch
def batch_differing_spaces(spaces: list[Space]):
"""Batch a Sequence of spaces that allows the subspaces to contain minor differences."""
assert len(spaces) > 0
assert all(isinstance(space, type(spaces[0])) for space in spaces)
assert type(spaces[0]) in batch_differing_spaces.registry

return batch_differing_spaces.dispatch(type(spaces[0]))(spaces)


@batch_differing_spaces.register(Box)
def _batch_differing_spaces_box(spaces: list[Box]):
assert all(spaces[0].dtype == space for space in spaces)

return Box(
low=np.array([space.low for space in spaces]),
high=np.array([space.high for space in spaces]),
dtype=spaces[0].dtype,
seed=deepcopy(spaces[0].np_random),
)


@batch_differing_spaces.register(Discrete)
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
return MultiDiscrete(
nvec=np.array([space.n for space in spaces]),
start=np.array([space.start for space in spaces]),
seed=deepcopy(spaces[0].np_random),
)


@batch_differing_spaces.register(MultiDiscrete)
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
return Box(
low=np.array([space.start for space in spaces]),
high=np.array([space.start + space.nvec for space in spaces]) - 1,
dtype=spaces[0].dtype,
seed=deepcopy(spaces[0].np_random),
)


@batch_differing_spaces.register(MultiBinary)
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
assert all(spaces[0].shape == space.shape for space in spaces)

return Box(
low=0,
high=1,
shape=(len(spaces),) + spaces[0].shape,
dtype=spaces[0].dtype,
seed=deepcopy(spaces[0].np_random),
)


@batch_differing_spaces.register(Tuple)
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
return Tuple(
tuple(
batch_differing_spaces(subspaces)
for subspaces in zip(*[space.spaces for space in spaces])
),
seed=deepcopy(spaces[0].np_random),
)


@batch_differing_spaces.register(Dict)
def _batch_differing_spaces_dict(spaces: list[Dict]):
assert all(spaces[0].keys() == space.keys() for space in spaces)

return Dict(
{
key: batch_differing_spaces([space[key] for space in spaces])
for key in spaces[0].keys()
},
seed=deepcopy(spaces[0].np_random),
)


@batch_differing_spaces.register(Graph)
@batch_differing_spaces.register(Text)
@batch_differing_spaces.register(Sequence)
@batch_differing_spaces.register(OneOf)
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
return Tuple(spaces, seed=deepcopy(spaces[0].np_random))


def all_spaces_have_same_shape(spaces):
"""Check if all spaces have the same size."""
if not spaces:
return True # An empty list is considered to have the same shape

def get_space_shape(space):
if isinstance(space, Box):
return space.shape
elif isinstance(space, Discrete):
return () # Discrete spaces are considered scalar
elif isinstance(space, Dict):
return tuple(get_space_shape(s) for s in space.spaces.values())
elif isinstance(space, Tuple):
return tuple(get_space_shape(s) for s in space.spaces)
else:
raise ValueError(f"Unsupported space type: {type(space)}")

first_shape = get_space_shape(spaces[0])
return all(get_space_shape(space) == first_shape for space in spaces[1:])


def all_spaces_have_same_type(spaces):
"""Check if all spaces have the same space type (Box, Discrete, etc)."""
if not spaces:
return True # An empty list is considered to have the same type

# Get the type of the first space
first_type = type(spaces[0])

# Check if all spaces have the same type as the first one
return all(isinstance(space, first_type) for space in spaces)
Loading

0 comments on commit d20ac56

Please sign in to comment.