diff --git a/skrl/utils/spaces/torch/__init__.py b/skrl/utils/spaces/torch/__init__.py index c87ec936..7bfbb55e 100644 --- a/skrl/utils/spaces/torch/__init__.py +++ b/skrl/utils/spaces/torch/__init__.py @@ -7,7 +7,7 @@ import torch -__all__ = ["tensorize_space", "flatten_tensorized_space", "compute_space_size"] +__all__ = ["tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"] def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, torch.device]] = None) -> Any: @@ -18,14 +18,14 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor (-1, space's shape). Composite spaces (:py:class:`~gymnasium.spaces.Dict`, :py:class:`~gymnasium.spaces.Tuple`, and :py:class:`~gymnasium.spaces.Sequence`) are converted by recursively calling this function on their elements. - :param space: Gymnasium space - :param x: Sample/value of the given space to tensorize to + :param space: Gymnasium space. + :param x: Sample/value of the given space to tensorize to. :param device: Device on which a tensor/array is or will be allocated (default: ``None``). - This parameter is used when the space value is not a PyTorch tensor (e.g.: numpy array, number) + This parameter is used when the space value is not a PyTorch tensor (e.g.: numpy array, number). - :raises ValueError: The conversion of the sample/value type is not supported for the given space + :raises ValueError: The conversion of the sample/value type is not supported for the given space. - :return: Sample/value space with items converted to tensors + :return: Sample/value space with items converted to tensors. 
""" # fundamental spaces # Box @@ -68,28 +68,71 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor return tuple([tensorize_space(space.feature_space, _x, device) for _x in x]) def flatten_tensorized_space(x: Any) -> torch.Tensor: - """Flatten a tensorized space (see :py:func:`~skrl.utils.spaces.torch.tensorize_space`). + """Flatten a tensorized space. - :param x: Tensorized space sample/value + :param x: Tensorized space sample/value. - :return: A tensor. The returned tensor will have shape (batch, space size) + :return: A tensor. The returned tensor will have shape (batch, space size). """ + # fundamental spaces + # Box / Discrete / MultiDiscrete if isinstance(x, torch.Tensor): return x.view(x.shape[0], -1) if x.ndim > 1 else x.view(1, -1) + # composite spaces + # Dict elif isinstance(x, dict): return torch.cat([flatten_tensorized_space(x[k])for k in sorted(x.keys())], dim=-1) + # Tuple / Sequence elif type(x) in [list, tuple]: return torch.cat([flatten_tensorized_space(_x) for _x in x], dim=-1) +def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: torch.Tensor) -> Any: + """Unflatten a tensor to create a tensorized space. + + :param space: Gymnasium space. + :param x: A tensor with shape (batch, space size). + + :return: Tensorized space value. 
+ """ + # fundamental spaces + # Box + if isinstance(space, spaces.Box): + return x.view(-1, *space.shape) + # Discrete + elif isinstance(space, spaces.Discrete): + return x.view(-1, 1) + # MultiDiscrete + elif isinstance(space, spaces.MultiDiscrete): + return x.view(-1, *space.shape) + # composite spaces + # Dict + elif isinstance(space, spaces.Dict): + start = 0 + output = {} + for k in sorted(space.keys()): + end = start + compute_space_size(space[k], occupied_size=True) + output[k] = unflatten_tensorized_space(space[k], x[:, start:end]) + start = end + return output + # Tuple + elif isinstance(space, spaces.Tuple): + start = 0 + output = [] + for s in space: + end = start + compute_space_size(s, occupied_size=True) + output.append(unflatten_tensorized_space(s, x[:, start:end])) + start = end + return output + def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int: - """Get the size (number of elements) of a space + """Get the size (number of elements) of a space. - :param space: Gymnasium space + :param space: Gymnasium space. :param occupied_size: Whether the number of elements occupied by the space is returned (default: ``False``). It only affects :py:class:`~gymnasium.spaces.Discrete` (occupied space is 1), - and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces) + and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces). - :return: Size of the space (number of elements) + :return: Size of the space (number of elements). 
""" if occupied_size: # fundamental spaces @@ -115,13 +158,13 @@ def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_ return gymnasium.spaces.flatdim(space) def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "torch"], device = None) -> Any: - """Generates a random sample from the specified space + """Generates a random sample from the specified space. - :param space: Gymnasium space - :param batch_size: Size of the sampled batch (default: ``1``) - :param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``) + :param space: Gymnasium space. + :param batch_size: Size of the sampled batch (default: ``1``). + :param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``). :param device: Device on which a tensor/array is or will be allocated (default: ``None``). - This parameter is used when the backend is ``"torch"`` + This parameter is used when the backend is ``"torch"``. 
def _check_backend(x, backend):
    """Assert that ``x`` is an instance of the array type of the given backend.

    :param x: Value to check.
    :param backend: Either ``"torch"`` (:py:class:`torch.Tensor`) or ``"numpy"`` (:py:class:`numpy.ndarray`).

    :raises ValueError: The backend is neither ``"torch"`` nor ``"numpy"``.
    """
    if backend == "torch":
        assert isinstance(x, torch.Tensor)
    elif backend == "numpy":
        assert isinstance(x, np.ndarray)
    else:
        raise ValueError(f"Invalid backend type: {backend}")


def check_sampled_space(space, x, n, backend):
    """Recursively check the container type and the batched shape of a space sample/value.

    :param space: Gymnasium space that ``x`` is expected to match.
    :param x: Sample/value to check.
    :param n: Expected batch size (first dimension of every fundamental item).
    :param backend: Expected backend of the fundamental items (``"torch"`` or ``"numpy"``).

    :raises ValueError: The space type or the backend is not supported.
    """
    if isinstance(space, gym.spaces.Box):
        _check_backend(x, backend)
        assert x.shape == (n, *space.shape)
    elif isinstance(space, gym.spaces.Discrete):
        _check_backend(x, backend)
        assert x.shape == (n, 1)
    elif isinstance(space, gym.spaces.MultiDiscrete):
        # the backend must be checked here as well, like for the other fundamental spaces
        _check_backend(x, backend)
        assert x.shape == (n, *space.nvec.shape)
    elif isinstance(space, gym.spaces.Dict):
        # pair items by key (not by position) and make sure no item is missing or extra
        assert set(x.keys()) == set(space.keys())
        for k in space.keys():
            check_sampled_space(space[k], x[k], n, backend)
    elif isinstance(space, gym.spaces.Tuple):
        # zip stops at the shortest input, so a length mismatch must be checked explicitly
        assert len(x) == len(space)
        for s, _x in zip(space, x):
            check_sampled_space(s, _x, n, backend)
    else:
        raise ValueError(f"Invalid space type: {type(space)}")


@hypothesis.given(space=space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_tensorize_space(capsys, space: gym.spaces.Space):
    # a tensorized space must only contain torch tensors, whatever the input type is

    # single (unbatched) sample generated by the space itself
    tensorized_space = tensorize_space(space, space.sample())
    check_sampled_space(space, tensorized_space, 1, backend="torch")

    # tensorizing an already tensorized space must preserve types and shapes
    tensorized_space = tensorize_space(space, tensorized_space)
    check_sampled_space(space, tensorized_space, 1, backend="torch")

    # batched numpy samples
    sampled_space = sample_space(space, 5, backend="numpy")
    tensorized_space = tensorize_space(space, sampled_space)
    check_sampled_space(space, tensorized_space, 5, backend="torch")

    # batched torch samples
    sampled_space = sample_space(space, 5, backend="torch")
    tensorized_space = tensorize_space(space, sampled_space)
    check_sampled_space(space, tensorized_space, 5, backend="torch")


@hypothesis.given(space=space_stategy(), batch_size=st.integers(min_value=1, max_value=10))
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_sample_space(capsys, space: gym.spaces.Space, batch_size: int):
    sampled_space = sample_space(space, batch_size, backend="numpy")
    check_sampled_space(space, sampled_space, batch_size, backend="numpy")

    sampled_space = sample_space(space, batch_size, backend="torch")
    check_sampled_space(space, sampled_space, batch_size, backend="torch")


@hypothesis.given(space=space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_flatten_tensorized_space(capsys, space: gym.spaces.Space):
    space_size = compute_space_size(space, occupied_size=True)

    # single (unbatched) sample
    tensorized_space = tensorize_space(space, space.sample())
    flattened_space = flatten_tensorized_space(tensorized_space)
    assert flattened_space.shape == (1, space_size)

    # batched sample
    tensorized_space = sample_space(space, batch_size=5, backend="torch")
    flattened_space = flatten_tensorized_space(tensorized_space)
    assert flattened_space.shape == (5, space_size)


@hypothesis.given(space=space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_unflatten_tensorized_space(capsys, space: gym.spaces.Space):
    # single (unbatched) sample
    tensorized_space = tensorize_space(space, space.sample())
    flattened_space = flatten_tensorized_space(tensorized_space)
    unflattened_space = unflatten_tensorized_space(space, flattened_space)
    check_sampled_space(space, unflattened_space, 1, backend="torch")
    # shapes alone cannot detect slices assigned to the wrong sub-space: check the round trip
    assert torch.equal(flatten_tensorized_space(unflattened_space), flattened_space)

    # batched sample
    tensorized_space = sample_space(space, batch_size=5, backend="torch")
    flattened_space = flatten_tensorized_space(tensorized_space)
    unflattened_space = unflatten_tensorized_space(space, flattened_space)
    check_sampled_space(space, unflattened_space, 5, backend="torch")
    assert torch.equal(flatten_tensorized_space(unflattened_space), flattened_space)