Commit a92ebed

Convert gym wrapper spaces to gymnasium
Toni-SM committed Oct 1, 2024
1 parent 43bdc34 commit a92ebed
Showing 2 changed files with 25 additions and 52 deletions.
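
What the commit does, in short: GymWrapper now reports gymnasium spaces even when the wrapped environment is an OpenAI Gym one, by passing every exposed space through skrl's convert_gym_space utility, and the wrapper-local _observation_to_tensor method is dropped in favor of the shared tensorize_space and flatten_tensorized_space helpers. As a rough sketch of what converting a gym space to a gymnasium space entails (illustrative only; the helper name gym_space_to_gymnasium below is hypothetical, not skrl's actual implementation):

# Illustrative sketch of a gym-to-gymnasium space conversion.
# skrl ships its own convert_gym_space in skrl.utils.spaces.torch;
# this stand-in only shows the idea.
import gym
import gymnasium


def gym_space_to_gymnasium(space: gym.Space) -> gymnasium.Space:
    """Map an OpenAI Gym space to its gymnasium counterpart."""
    if isinstance(space, gym.spaces.Box):
        return gymnasium.spaces.Box(low=space.low, high=space.high, dtype=space.dtype)
    elif isinstance(space, gym.spaces.Discrete):
        return gymnasium.spaces.Discrete(n=int(space.n))
    elif isinstance(space, gym.spaces.MultiDiscrete):
        return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
    elif isinstance(space, gym.spaces.Tuple):
        return gymnasium.spaces.Tuple([gym_space_to_gymnasium(s) for s in space.spaces])
    elif isinstance(space, gym.spaces.Dict):
        return gymnasium.spaces.Dict({k: gym_space_to_gymnasium(s) for k, s in space.spaces.items()})
    raise ValueError(f"Unsupported gym space type: {type(space)}")

Centralizing the conversion means observation_space, action_space and their vectorized single_* variants all return the same gymnasium types, so downstream skrl components never see raw gym spaces.
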
68 changes: 20 additions & 48 deletions skrl/envs/wrappers/torch/gym_envs.py
@@ -1,13 +1,15 @@
 from typing import Any, Optional, Tuple
 
 import gym
+import gymnasium
 from packaging import version
 
 import numpy as np
 import torch
 
 from skrl import logger
 from skrl.envs.wrappers.torch.base import Wrapper
+from skrl.utils.spaces.torch import convert_gym_space, flatten_tensorized_space, tensorize_space
 
 
 class GymWrapper(Wrapper):
@@ -40,51 +42,20 @@ def __init__(self, env: Any) -> None:
             logger.warning(f"Using a deprecated version of OpenAI Gym's API: {gym.__version__}")
 
     @property
-    def observation_space(self) -> gym.Space:
+    def observation_space(self) -> gymnasium.Space:
         """Observation space
         """
         if self._vectorized:
-            return self._env.single_observation_space
-        return self._env.observation_space
+            return convert_gym_space(self._env.single_observation_space)
+        return convert_gym_space(self._env.observation_space)
 
     @property
-    def action_space(self) -> gym.Space:
+    def action_space(self) -> gymnasium.Space:
         """Action space
         """
         if self._vectorized:
-            return self._env.single_action_space
-        return self._env.action_space
-
-    def _observation_to_tensor(self, observation: Any, space: Optional[gym.Space] = None) -> torch.Tensor:
-        """Convert the OpenAI Gym observation to a flat tensor
-
-        :param observation: The OpenAI Gym observation to convert to a tensor
-        :type observation: Any supported OpenAI Gym observation space
-
-        :raises: ValueError if the observation space type is not supported
-
-        :return: The observation as a flat tensor
-        :rtype: torch.Tensor
-        """
-        observation_space = self._env.observation_space if self._vectorized else self.observation_space
-        space = space if space is not None else observation_space
-
-        if self._vectorized and isinstance(space, gym.spaces.MultiDiscrete):
-            return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1)
-        elif isinstance(observation, int):
-            return torch.tensor(observation, device=self.device, dtype=torch.int64).view(self.num_envs, -1)
-        elif isinstance(observation, np.ndarray):
-            return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
-        elif isinstance(space, gym.spaces.Discrete):
-            return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
-        elif isinstance(space, gym.spaces.Box):
-            return torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
-        elif isinstance(space, gym.spaces.Dict):
-            tmp = torch.cat([self._observation_to_tensor(observation[k], space[k]) \
-                for k in sorted(space.keys())], dim=-1).view(self.num_envs, -1)
-            return tmp
-        else:
-            raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue")
+            return convert_gym_space(self._env.single_action_space)
+        return convert_gym_space(self._env.action_space)
 
     def _tensor_to_action(self, actions: torch.Tensor) -> Any:
         """Convert the action to the OpenAI Gym expected format
@@ -97,21 +68,21 @@ def _tensor_to_action(self, actions: torch.Tensor) -> Any:
         :return: The action in the OpenAI Gym format
         :rtype: Any supported OpenAI Gym action space
         """
-        space = self._env.action_space if self._vectorized else self.action_space
+        space = convert_gym_space(self._env.action_space) if self._vectorized else self.action_space
 
         if self._vectorized:
-            if isinstance(space, gym.spaces.MultiDiscrete):
+            if isinstance(space, gymnasium.spaces.MultiDiscrete):
                 return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
-            elif isinstance(space, gym.spaces.Tuple):
-                if isinstance(space[0], gym.spaces.Box):
+            elif isinstance(space, gymnasium.spaces.Tuple):
+                if isinstance(space[0], gymnasium.spaces.Box):
                     return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(space.shape)
-                elif isinstance(space[0], gym.spaces.Discrete):
+                elif isinstance(space[0], gymnasium.spaces.Discrete):
                     return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1)
-        if isinstance(space, gym.spaces.Discrete):
+        if isinstance(space, gymnasium.spaces.Discrete):
             return actions.item()
-        elif isinstance(space, gym.spaces.MultiDiscrete):
+        elif isinstance(space, gymnasium.spaces.MultiDiscrete):
             return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
-        elif isinstance(space, gym.spaces.Box):
+        elif isinstance(space, gymnasium.spaces.Box):
             return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
         raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")
 
@@ -138,7 +109,7 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
             observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions))
 
         # convert response to torch
-        observation = self._observation_to_tensor(observation)
+        observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device))
         reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
         terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
         truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
@@ -164,7 +135,7 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
                     self._info = {}
                 else:
                     observation, self._info = self._env.reset()
-                self._observation = self._observation_to_tensor(observation)
+                self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device))
                 self._reset_once = False
             return self._observation, self._info
 
@@ -173,7 +144,8 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
             info = {}
         else:
             observation, info = self._env.reset()
-        return self._observation_to_tensor(observation), info
+        observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device))
+        return observation, info
 
     def render(self, *args, **kwargs) -> Any:
         """Render the environment
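
Taken together, wrapping a plain gym environment now yields gymnasium spaces and flat torch tensors from reset and step. A usage sketch (hypothetical, not part of the commit; assumes gym, gymnasium, torch and skrl are installed, and that wrap_env with wrapper="gym" routes to this GymWrapper; Pendulum-v1 matches the environment used by the tests below):

# Hypothetical usage sketch; not part of the commit.
import gym
import gymnasium
import torch

from skrl.envs.wrappers.torch import wrap_env

# wrap a classic OpenAI Gym environment with skrl's torch GymWrapper
env = wrap_env(gym.make("Pendulum-v1"), wrapper="gym")

# spaces are now gymnasium spaces, converted from the underlying gym spaces
assert isinstance(env.observation_space, gymnasium.Space)
assert isinstance(env.action_space, gymnasium.Space)

# reset and step return batched torch tensors
observation, info = env.reset()
print(observation.shape)  # expected: torch.Size([1, 3]) for a single Pendulum-v1 env

action = torch.zeros((1, 1))  # zero torque, shaped (num_envs, action_dim)
observation, reward, terminated, truncated, info = env.step(action)
print(reward.shape)  # expected: torch.Size([1, 1])
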
9 changes: 5 additions & 4 deletions tests/torch/test_torch_wrapper_gym.py
@@ -2,6 +2,7 @@
 
 from collections.abc import Mapping
 import gym
+import gymnasium
 
 import torch
 
@@ -21,8 +22,8 @@ def test_env(capsys: pytest.CaptureFixture):
 
     # check properties
     assert env.state_space is None
-    assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,)
-    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
+    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,)
+    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
     assert isinstance(env.num_envs, int) and env.num_envs == num_envs
     assert isinstance(env.num_agents, int) and env.num_agents == 1
     assert isinstance(env.device, torch.device)
@@ -59,8 +60,8 @@ def test_vectorized_env(capsys: pytest.CaptureFixture, vectorization_mode: str):
 
     # check properties
    assert env.state_space is None
-    assert isinstance(env.observation_space, gym.Space) and env.observation_space.shape == (3,)
-    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
+    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,)
+    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
     assert isinstance(env.num_envs, int) and env.num_envs == num_envs
     assert isinstance(env.num_agents, int) and env.num_agents == 1
     assert isinstance(env.device, torch.device)
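
The test changes only swap the expected space type from gym.Space to gymnasium.Space. For the vectorized case the wrapper reports the converted per-environment (single_*) spaces, which is why the asserted shapes stay (3,) and (1,). A compact standalone check in the same spirit (hypothetical, not part of the commit; gym.vector.make is the classic Gym vector API, and the wrap_env call is assumed to route to GymWrapper):

# Hypothetical standalone check mirroring the updated tests.
import gym
import gymnasium

from skrl.envs.wrappers.torch import wrap_env

num_envs = 4
env = wrap_env(gym.vector.make("Pendulum-v1", num_envs=num_envs), wrapper="gym")

# the wrapper converts single_observation_space / single_action_space,
# so the reported shapes are per environment, not batched
assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (3,)
assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
assert env.num_envs == num_envs
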
