From 43bdc34515e5aeb52f1dcc6be65e63ac8f393583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Tue, 1 Oct 2024 15:41:56 -0400 Subject: [PATCH] Add utility to convert gym spaces to gymnasium spaces --- skrl/utils/spaces/torch/__init__.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/skrl/utils/spaces/torch/__init__.py b/skrl/utils/spaces/torch/__init__.py index cd692d45..aba89fe3 100644 --- a/skrl/utils/spaces/torch/__init__.py +++ b/skrl/utils/spaces/torch/__init__.py @@ -7,9 +7,34 @@ import torch -__all__ = ["tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"] +__all__ = ["convert_gym_space", "tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"] +def convert_gym_space(space: "gym.Space") -> gymnasium.Space: + """Converts a gym space to a gymnasium space. + + :param space: Gym space to convert to. + + :raises NotImplementedError: The conversion is not supported for the given space. + + :return: Converted space. + """ + import gym + + if isinstance(space, gym.spaces.Discrete): + return spaces.Discrete(n=space.n) + elif isinstance(space, gym.spaces.Box): + return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype) + elif isinstance(space, gym.spaces.MultiDiscrete): + return spaces.MultiDiscrete(nvec=space.nvec) + elif isinstance(space, gym.spaces.Tuple): + return spaces.Tuple(spaces=tuple(map(convert_gym_space, space.spaces))) + elif isinstance(space, gym.spaces.Dict): + return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()}) + elif isinstance(space, gym.spaces.Sequence): + return spaces.Sequence(space=convert_gym_space(space.feature_space)) + raise NotImplementedError(f"Unsupported space ({space})") + def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, torch.device]] = None) -> Any: """Convert the sample/value items of a given gymnasium space to PyTorch tensors.