Skip to content

Commit

Permalink
Add space utility to untensorize spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 1, 2024
1 parent 1579d1a commit d96b13a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
50 changes: 47 additions & 3 deletions skrl/utils/spaces/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
import torch


__all__ = ["convert_gym_space", "tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"]


def convert_gym_space(space: "gym.Space") -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.
Expand Down Expand Up @@ -89,6 +86,53 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor
elif isinstance(space, spaces.Tuple):
return tuple([tensorize_space(s, _x, device) for s, _x in zip(space, x)])

def _tensor_to_space_array(space: spaces.Space, x: Any) -> np.ndarray:
    """Copy a tensor into a NumPy array that uses the dtype of the given space.

    :param space: Gymnasium space whose ``dtype`` is used for the resulting array.
    :param x: Value to convert. It must be a ``torch.Tensor``.

    :raises ValueError: The value is not a tensor.

    :return: NumPy array holding a copy of the tensor data.
    """
    if not isinstance(x, torch.Tensor):
        raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
    # detach first: calling ``.numpy()`` on a tensor that requires grad raises a RuntimeError.
    # ``np.array`` makes a copy, so the result does not alias the (CPU) tensor's memory
    return np.array(x.detach().cpu().numpy(), dtype=space.dtype)


def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool = True) -> Any:
    """Convert a tensorized space to a gymnasium space with expected sample/value item types.

    Tensors are converted according to the type of ``space``:

    - ``Box``: NumPy array reshaped to ``space.shape`` (or ``(-1, *space.shape)`` when batched).
    - ``Discrete``: Python scalar (via ``.item()``) when the batch dimension is squeezed,
      otherwise a NumPy array of shape ``(-1, 1)``.
    - ``MultiDiscrete``: NumPy array reshaped to ``space.nvec.shape``
      (or ``(-1, *space.nvec.shape)`` when batched).
    - ``Dict`` / ``Tuple``: converted recursively, item by item.

    :param space: Gymnasium space.
    :param x: Tensorized space (Sample/value space where items are tensors).
    :param squeeze_batch_dimension: Whether to remove the batch dimension. If True, only the
                                    sample/value with a batch dimension of size 1 will be affected

    :raises ValueError: The sample/value type is not a tensor.

    :return: Sample/value space with expected item types.
             ``None`` is returned if ``x`` is ``None`` or if the space type is not one of those listed above.
    """
    if x is None:
        return None

    # composite spaces: recurse into each sub-space
    # Dict
    if isinstance(space, spaces.Dict):
        return {k: untensorize_space(s, x[k], squeeze_batch_dimension) for k, s in space.items()}
    # Tuple
    if isinstance(space, spaces.Tuple):
        return tuple(untensorize_space(s, _x, squeeze_batch_dimension) for s, _x in zip(space, x))

    # fundamental spaces
    if isinstance(space, (spaces.Box, spaces.Discrete, spaces.MultiDiscrete)):
        array = _tensor_to_space_array(space, x)
        # a 0-dim array has no batch dimension to index: treat it as a single (un-batched) value
        single_sample = array.ndim == 0 or array.shape[0] == 1
        squeeze = squeeze_batch_dimension and single_sample
        # Box
        if isinstance(space, spaces.Box):
            return array.reshape(space.shape) if squeeze else array.reshape(-1, *space.shape)
        # Discrete
        if isinstance(space, spaces.Discrete):
            return array.item() if squeeze else array.reshape(-1, 1)
        # MultiDiscrete
        return array.reshape(space.nvec.shape) if squeeze else array.reshape(-1, *space.nvec.shape)

    # no conversion is implemented here for any other space type
    return None

def flatten_tensorized_space(x: Any) -> torch.Tensor:
"""Flatten a tensorized space.
Expand Down
30 changes: 29 additions & 1 deletion tests/torch/test_utils_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
flatten_tensorized_space,
sample_space,
tensorize_space,
unflatten_tensorized_space
unflatten_tensorized_space,
untensorize_space
)

from ..stategies import gym_space_stategy, gymnasium_space_stategy
Expand Down Expand Up @@ -95,6 +96,33 @@ def check_tensorized_space(s, x, n):
tensorized_space = tensorize_space(space, sampled_space)
check_tensorized_space(space, tensorized_space, 5)

@hypothesis.given(space=gymnasium_space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_untensorize_space(capsys, space: gymnasium.spaces.Space):
    """Check that ``untensorize_space`` returns items of the expected type and shape.

    A single sample is tensorized and then untensorized both with and without
    squeezing the batch dimension.
    """

    def check_untensorized_space(s, x, squeeze_batch_dimension):
        # NOTE: the expected shape is computed before comparing. Writing
        # ``assert x.shape == a if flag else b`` parses as ``(x.shape == a) if flag else b``,
        # which asserts a (always truthy) non-empty tuple when ``flag`` is False
        if isinstance(s, gymnasium.spaces.Box):
            assert isinstance(x, np.ndarray)
            expected_shape = s.shape if squeeze_batch_dimension else (1, *s.shape)
            assert x.shape == expected_shape
        elif isinstance(s, gymnasium.spaces.Discrete):
            if squeeze_batch_dimension:
                assert isinstance(x, int)
            else:
                assert isinstance(x, np.ndarray)
                assert x.shape == (1, 1)
        elif isinstance(s, gymnasium.spaces.MultiDiscrete):
            assert isinstance(x, np.ndarray)
            expected_shape = s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape)
            assert x.shape == expected_shape
        elif isinstance(s, gymnasium.spaces.Dict):
            # pair sub-spaces and values by key instead of relying on parallel iteration order
            for key, subspace in s.items():
                check_untensorized_space(subspace, x[key], squeeze_batch_dimension)
        elif isinstance(s, gymnasium.spaces.Tuple):
            for subspace, item in zip(s, x):
                check_untensorized_space(subspace, item, squeeze_batch_dimension)
        else:
            raise ValueError(f"Invalid space type: {type(s)}")

    tensorized_space = tensorize_space(space, space.sample())

    untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=False)
    check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=False)

    untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=True)
    check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=True)

@hypothesis.given(space=gymnasium_space_stategy(), batch_size=st.integers(min_value=1, max_value=10))
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int):
Expand Down

0 comments on commit d96b13a

Please sign in to comment.