Skip to content

Commit

Permalink
Add utility to convert gym spaces to gymnasium spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 1, 2024
1 parent 834e94a commit 43bdc34
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion skrl/utils/spaces/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,34 @@
import torch


__all__ = ["tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"]
__all__ = ["convert_gym_space", "tensorize_space", "flatten_tensorized_space", "compute_space_size", "unflatten_tensorized_space", "sample_space"]


def convert_gym_space(space: "gym.Space") -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.
:param space: Gym space to convert to.
:raises NotImplementedError: The conversion is not supported for the given space.
:return: Converted space.
"""
import gym

if isinstance(space, gym.spaces.Discrete):
return spaces.Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
elif isinstance(space, gym.spaces.MultiDiscrete):
return spaces.MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.Tuple):
return spaces.Tuple(spaces=tuple(map(convert_gym_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()})
elif isinstance(space, gym.spaces.Sequence):
return spaces.Sequence(space=convert_gym_space(space.feature_space))
raise NotImplementedError(f"Unsupported space ({space})")

def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, torch.device]] = None) -> Any:
"""Convert the sample/value items of a given gymnasium space to PyTorch tensors.
Expand Down

0 comments on commit 43bdc34

Please sign in to comment.