Skip to content

Commit

Permalink
Use spaces utils to process actions in Gym wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 2, 2024
1 parent fbde435 commit 6a4b24d
Showing 1 changed file with 14 additions and 33 deletions.
47 changes: 14 additions & 33 deletions skrl/envs/wrappers/torch/gym_envs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, Tuple

import gym
import gymnasium
Expand All @@ -9,7 +9,13 @@

from skrl import logger
from skrl.envs.wrappers.torch.base import Wrapper
from skrl.utils.spaces.torch import convert_gym_space, flatten_tensorized_space, tensorize_space
from skrl.utils.spaces.torch import (
convert_gym_space,
flatten_tensorized_space,
tensorize_space,
unflatten_tensorized_space,
untensorize_space
)


class GymWrapper(Wrapper):
Expand Down Expand Up @@ -57,35 +63,6 @@ def action_space(self) -> gymnasium.Space:
return convert_gym_space(self._env.single_action_space)
return convert_gym_space(self._env.action_space)

def _tensor_to_action(self, actions: torch.Tensor) -> Any:
"""Convert the action to the OpenAI Gym expected format
:param actions: The actions to perform
:type actions: torch.Tensor
:raise ValueError: If the action space type is not supported
:return: The action in the OpenAI Gym format
:rtype: Any supported OpenAI Gym action space
"""
space = convert_gym_space(self._env.action_space) if self._vectorized else self.action_space

if self._vectorized:
if isinstance(space, gymnasium.spaces.MultiDiscrete):
return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
elif isinstance(space, gymnasium.spaces.Tuple):
if isinstance(space[0], gymnasium.spaces.Box):
return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(space.shape)
elif isinstance(space[0], gymnasium.spaces.Discrete):
return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1)
if isinstance(space, gymnasium.spaces.Discrete):
return actions.item()
elif isinstance(space, gymnasium.spaces.MultiDiscrete):
return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
elif isinstance(space, gymnasium.spaces.Box):
return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")

def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Perform a step in the environment
Expand All @@ -95,8 +72,12 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of torch.Tensor and any other info
"""
actions = untensorize_space(self.action_space,
unflatten_tensorized_space(self.action_space, actions),
squeeze_batch_dimension=not self._vectorized)

if self._deprecated_api:
observation, reward, terminated, info = self._env.step(self._tensor_to_action(actions))
observation, reward, terminated, info = self._env.step(actions)
# truncated: https://gymnasium.farama.org/tutorials/handling_time_limits
if type(info) is list:
truncated = np.array([d.get("TimeLimit.truncated", False) for d in info], dtype=terminated.dtype)
Expand All @@ -106,7 +87,7 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
if truncated:
terminated = False
else:
observation, reward, terminated, truncated, info = self._env.step(self._tensor_to_action(actions))
observation, reward, terminated, truncated, info = self._env.step(actions)

# convert response to torch
observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device))
Expand Down

0 comments on commit 6a4b24d

Please sign in to comment.