Skip to content

Commit

Permalink
Update test file
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Sep 26, 2024
1 parent c8fe715 commit 65cbc82
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 32 deletions.
79 changes: 61 additions & 18 deletions skrl/utils/spaces/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch


__all__ = ["tensorize_space", "flatten_tensorized_space", "compute_space_size"]
__all__ = ["tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"]


def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, torch.device]] = None) -> Any:
Expand All @@ -18,14 +18,14 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor
(-1, space's shape). Composite spaces (:py:class:`~gymnasium.spaces.Dict`, :py:class:`~gymnasium.spaces.Tuple`,
and :py:class:`~gymnasium.spaces.Sequence`) are converted by recursively calling this function on their elements.
:param space: Gymnasium space
:param x: Sample/value of the given space to tensorize to
:param space: Gymnasium space.
:param x: Sample/value of the given space to tensorize to.
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
This parameter is used when the space value is not a PyTorch tensor (e.g.: numpy array, number)
This parameter is used when the space value is not a PyTorch tensor (e.g.: numpy array, number).
:raises ValueError: The conversion of the sample/value type is not supported for the given space
:raises ValueError: The conversion of the sample/value type is not supported for the given space.
:return: Sample/value space with items converted to tensors
:return: Sample/value space with items converted to tensors.
"""
# fundamental spaces
# Box
Expand Down Expand Up @@ -68,28 +68,71 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor
return tuple([tensorize_space(space.feature_space, _x, device) for _x in x])

def flatten_tensorized_space(x: Any) -> torch.Tensor:
"""Flatten a tensorized space (see :py:func:`~skrl.utils.spaces.torch.tensorize_space`).
"""Flatten a tensorized space.
:param x: Tensorized space sample/value
:param x: Tensorized space sample/value.
:return: A tensor. The returned tensor will have shape (batch, space size)
:return: A tensor. The returned tensor will have shape (batch, space size).
"""
# fundamental spaces
# Box / Discrete / MultiDiscrete
if isinstance(x, torch.Tensor):
return x.view(x.shape[0], -1) if x.ndim > 1 else x.view(1, -1)
# composite spaces
# Dict
elif isinstance(x, dict):
return torch.cat([flatten_tensorized_space(x[k])for k in sorted(x.keys())], dim=-1)
# Tuple / Sequence
elif type(x) in [list, tuple]:
return torch.cat([flatten_tensorized_space(_x) for _x in x], dim=-1)

def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: torch.Tensor) -> Any:
"""Unflatten a tensor to create a tensorized space.
:param space: Gymnasium space.
:param x: A tensor with shape (batch, space size).
:return: Tensorized space value.
"""
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
return x.view(-1, *space.shape)
# Discrete
elif isinstance(space, spaces.Discrete):
return x.view(-1, 1)
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
return x.view(-1, *space.shape)
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
start = 0
output = {}
for k in sorted(space.keys()):
end = start + compute_space_size(space[k], occupied_size=True)
output[k] = unflatten_tensorized_space(space[k], x[:, start:end])
start = end
return output
# Tuple
elif isinstance(space, spaces.Tuple):
start = 0
output = []
for s in space:
end = start + compute_space_size(s, occupied_size=True)
output.append(unflatten_tensorized_space(s, x[:, start:end]))
start = end
return output

def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int:
"""Get the size (number of elements) of a space
"""Get the size (number of elements) of a space.
:param space: Gymnasium space
:param space: Gymnasium space.
:param occupied_size: Whether the number of elements occupied by the space is returned (default: ``False``).
It only affects :py:class:`~gymnasium.spaces.Discrete` (occupied space is 1),
and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces)
and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces).
:return: Size of the space (number of elements)
:return: Size of the space (number of elements).
"""
if occupied_size:
# fundamental spaces
Expand All @@ -115,13 +158,13 @@ def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_
return gymnasium.spaces.flatdim(space)

def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "torch"], device = None) -> Any:
"""Generates a random sample from the specified space
"""Generates a random sample from the specified space.
:param space: Gymnasium space
:param batch_size: Size of the sampled batch (default: ``1``)
:param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``)
:param space: Gymnasium space.
:param batch_size: Size of the sampled batch (default: ``1``).
:param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``).
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
This parameter is used when the backend is ``"torch"``
This parameter is used when the backend is ``"torch"``.
:return: Sample of the space
"""
Expand Down
90 changes: 76 additions & 14 deletions tests/torch/test_utils_spaces.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,44 @@
import hypothesis
import hypothesis.strategies as st
import pytest

import gymnasium as gym

import numpy as np
import torch

from skrl.utils.spaces.torch import compute_space_size, flatten_tensorized_space, tensorize_space
from skrl.utils.spaces.torch import (
compute_space_size,
flatten_tensorized_space,
sample_space,
tensorize_space,
unflatten_tensorized_space
)


def _check_backend(x, backend):
if backend == "torch":
assert isinstance(x, torch.Tensor)
elif backend == "numpy":
assert isinstance(x, np.ndarray)
else:
raise ValueError(f"Invalid backend type: {backend}")

def check_sampled_space(space, x, n, backend):
if isinstance(space, gym.spaces.Box):
_check_backend(x, backend)
assert x.shape == (n, *space.shape)
elif isinstance(space, gym.spaces.Discrete):
_check_backend(x, backend)
assert x.shape == (n, 1)
elif isinstance(space, gym.spaces.MultiDiscrete):
assert x.shape == (n, *space.nvec.shape)
elif isinstance(space, gym.spaces.Dict):
list(map(check_sampled_space, space.values(), x.values(), [n] * len(space), [backend] * len(space)))
elif isinstance(space, gym.spaces.Tuple):
list(map(check_sampled_space, space, x, [n] * len(space), [backend] * len(space)))
else:
raise ValueError(f"Invalid space type: {type(space)}")

@st.composite
def space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> gym.spaces.Space:
if not space_type:
Expand Down Expand Up @@ -63,33 +92,66 @@ def occupied_size(s):
@hypothesis.given(space=space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_tensorize_space(capsys, space: gym.spaces.Space):
def check_tensorized_space(s, x):
def check_tensorized_space(s, x, n):
if isinstance(s, gym.spaces.Box):
assert x.shape == torch.Size([1, *s.shape])
assert isinstance(x, torch.Tensor) and x.shape == (n, *s.shape)
elif isinstance(s, gym.spaces.Discrete):
assert x.ndim == 2 and x.shape[1] == 1
assert isinstance(x, torch.Tensor) and x.shape == (n, 1)
elif isinstance(s, gym.spaces.MultiDiscrete):
assert x.ndim == 2 and x.shape[1] == s.nvec.shape[0]
assert isinstance(x, torch.Tensor) and x.shape == (n, *s.nvec.shape)
elif isinstance(s, gym.spaces.Dict):
list(map(check_tensorized_space, s.values(), x.values()))
list(map(check_tensorized_space, s.values(), x.values(), [n] * len(s)))
elif isinstance(s, gym.spaces.Tuple):
list(map(check_tensorized_space, s, x))
list(map(check_tensorized_space, s, x, [n] * len(s)))
else:
raise ValueError(f"Invalid space type: {type(s)}")

tensorized_space = tensorize_space(space, space.sample())
check_tensorized_space(space, tensorized_space)
check_tensorized_space(space, tensorized_space, 1)

tensorized_space = tensorize_space(space, tensorized_space)
check_tensorized_space(space, tensorized_space)
check_tensorized_space(space, tensorized_space, 1)

sampled_space = sample_space(space, 5, backend="numpy")
tensorized_space = tensorize_space(space, sampled_space)
check_tensorized_space(space, tensorized_space, 5)

sampled_space = sample_space(space, 5, backend="torch")
tensorized_space = tensorize_space(space, sampled_space)
check_tensorized_space(space, tensorized_space, 5)

@hypothesis.given(space=space_stategy(), batch_size=st.integers(min_value=1, max_value=10))
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_sample_space(capsys, space: gym.spaces.Space, batch_size: int):

sampled_space = sample_space(space, batch_size, backend="numpy")
check_sampled_space(space, sampled_space, batch_size, backend="numpy")

sampled_space = sample_space(space, batch_size, backend="torch")
check_sampled_space(space, sampled_space, batch_size, backend="torch")

@hypothesis.given(space=space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_flatten_tensorized_space(capsys, space: gym.spaces.Space):
tensorized_space = tensorize_space(space, space.sample())
space_size = compute_space_size(space, occupied_size=True)

tensorized_space = tensorize_space(space, space.sample())
flattened_space = flatten_tensorized_space(tensorized_space)
assert flattened_space.shape == (1, space_size)

tensorized_space = sample_space(space, batch_size=5, backend="torch")
flattened_space = flatten_tensorized_space(tensorized_space)
assert flattened_space.shape == (5, space_size)

@hypothesis.given(space=space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_unflatten_tensorized_space(capsys, space: gym.spaces.Space):
tensorized_space = tensorize_space(space, space.sample())
flattened_space = flatten_tensorized_space(tensorized_space)
unflattened_space = unflatten_tensorized_space(space, flattened_space)
check_sampled_space(space, unflattened_space, 1, backend="torch")

tensorized_space = sample_space(space, batch_size=5, backend="torch")
flattened_space = flatten_tensorized_space(tensorized_space)
with capsys.disabled():
print(space, flattened_space.shape)
assert flattened_space.shape == torch.Size([1, space_size])
unflattened_space = unflatten_tensorized_space(space, flattened_space)
check_sampled_space(space, unflattened_space, 5, backend="torch")

0 comments on commit 65cbc82

Please sign in to comment.