From a92ebeda32bb71f006baf34da8ea93723105ddbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Tue, 1 Oct 2024 15:43:49 -0400 Subject: [PATCH] Convert gym wrapper spaces to gymnasium --- skrl/envs/wrappers/torch/gym_envs.py | 68 ++++++++------------------- tests/torch/test_torch_wrapper_gym.py | 9 ++-- 2 files changed, 25 insertions(+), 52 deletions(-) diff --git a/skrl/envs/wrappers/torch/gym_envs.py b/skrl/envs/wrappers/torch/gym_envs.py index 4e326818..8b3be16b 100644 --- a/skrl/envs/wrappers/torch/gym_envs.py +++ b/skrl/envs/wrappers/torch/gym_envs.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Tuple import gym +import gymnasium from packaging import version import numpy as np @@ -8,6 +9,7 @@ from skrl import logger from skrl.envs.wrappers.torch.base import Wrapper +from skrl.utils.spaces.torch import convert_gym_space, flatten_tensorized_space, tensorize_space class GymWrapper(Wrapper): @@ -40,51 +42,20 @@ def __init__(self, env: Any) -> None: logger.warning(f"Using a deprecated version of OpenAI Gym's API: {gym.__version__}") @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: """Observation space """ if self._vectorized: - return self._env.single_observation_space - return self._env.observation_space + return convert_gym_space(self._env.single_observation_space) + return convert_gym_space(self._env.observation_space) @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: """Action space """ if self._vectorized: - return self._env.single_action_space - return self._env.action_space - - def _observation_to_tensor(self, observation: Any, space: Optional[gym.Space] = None) -> torch.Tensor: - """Convert the OpenAI Gym observation to a flat tensor - - :param observation: The OpenAI Gym observation to convert to a tensor - :type observation: Any supported OpenAI Gym observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: torch.Tensor - """ - observation_space = self._env.observation_space if self._vectorized else self.observation_space - space = space if space is not None else observation_space - - if self._vectorized and isinstance(space, gym.spaces.MultiDiscrete): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, int): - return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gym.spaces.Discrete): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gym.spaces.Box): - return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1) - elif isinstance(space, gym.spaces.Dict): - tmp = torch.cat([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], dim=-1).view(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") + return convert_gym_space(self._env.single_action_space) + return convert_gym_space(self._env.action_space) def _tensor_to_action(self, actions: torch.Tensor) -> Any: """Convert the action to the OpenAI Gym expected format @@ -97,21 +68,21 @@ def _tensor_to_action(self, actions: torch.Tensor) -> Any: :return: The action in the OpenAI Gym format :rtype: Any supported OpenAI Gym action space """ - space = self._env.action_space if self._vectorized else self.action_space + space = convert_gym_space(self._env.action_space) if self._vectorized else self.action_space if self._vectorized: - if isinstance(space, gym.spaces.MultiDiscrete): + if isinstance(space, gymnasium.spaces.MultiDiscrete): return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - elif isinstance(space, gym.spaces.Tuple): - if isinstance(space[0], gym.spaces.Box): + elif isinstance(space, gymnasium.spaces.Tuple): + if isinstance(space[0], gymnasium.spaces.Box): return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(space.shape) - elif isinstance(space[0], gym.spaces.Discrete): + elif isinstance(space[0], gymnasium.spaces.Discrete): return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1) - if isinstance(space, gym.spaces.Discrete): + if isinstance(space, gymnasium.spaces.Discrete): return actions.item() - elif isinstance(space, gym.spaces.MultiDiscrete): + elif isinstance(space, gymnasium.spaces.MultiDiscrete): return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) - elif isinstance(space, gym.spaces.Box): + elif isinstance(space, gymnasium.spaces.Box): return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape) raise ValueError(f"Action space type {type(space)} not supported. Please report this issue") @@ -138,7 +109,7 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions)) # convert response to torch - observation = self._observation_to_tensor(observation) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1) terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) @@ -164,7 +135,7 @@ def reset(self) -> Tuple[torch.Tensor, Any]: self._info = {} else: observation, self._info = self._env.reset() - self._observation = self._observation_to_tensor(observation) + self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) self._reset_once = False return self._observation, self._info @@ -173,7 +144,8 @@ def reset(self) -> Tuple[torch.Tensor, Any]: info = {} else: observation, info = self._env.reset() - return self._observation_to_tensor(observation), info + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) + return observation, info def render(self, *args, **kwargs) -> Any: """Render the environment diff --git a/tests/torch/test_torch_wrapper_gym.py b/tests/torch/test_torch_wrapper_gym.py index a9e1aefb..cfee7672 100644 --- a/tests/torch/test_torch_wrapper_gym.py +++ b/tests/torch/test_torch_wrapper_gym.py @@ -2,6 +2,7 @@ from collections.abc import Mapping import gym +import gymnasium import torch @@ -21,8 +22,8 @@ def test_env(capsys: pytest.CaptureFixture): # check properties assert env.state_space is None - assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,) - assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) + assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,) + assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,) assert isinstance(env.num_envs, int) and env.num_envs == num_envs assert isinstance(env.num_agents, int) and env.num_agents == 1 assert isinstance(env.device, torch.device) @@ -59,8 +60,8 @@ def test_vectorized_env(capsys: pytest.CaptureFixture, vectorization_mode: str): # check properties assert env.state_space is None - assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,) - assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) + assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,) + assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,) assert isinstance(env.num_envs, int) and env.num_envs == num_envs assert isinstance(env.num_agents, int) and env.num_agents == 1 assert isinstance(env.device, torch.device)