Skip to content

Commit

Permalink
Use spaces utils to process actions in Isaac Gym preview wrapper
Browse files: browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 2, 2024
1 parent 7e41f75 commit b5e27b2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
1 change: 0 additions & 1 deletion skrl/envs/wrappers/torch/gymnasium_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import gymnasium

import numpy as np
import torch

from skrl import logger
Expand Down
11 changes: 8 additions & 3 deletions skrl/envs/wrappers/torch/isaacgym_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import torch

from skrl.envs.wrappers.torch.base import Wrapper
from skrl.utils.spaces.torch import convert_gym_space, flatten_tensorized_space, tensorize_space
from skrl.utils.spaces.torch import (
convert_gym_space,
flatten_tensorized_space,
tensorize_space,
unflatten_tensorized_space
)


class IsaacGymPreview2Wrapper(Wrapper):
Expand Down Expand Up @@ -42,7 +47,7 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of torch.Tensor and any other info
"""
observations, reward, terminated, self._info = self._env.step(actions)
observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions))
self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations))
truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated)
return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info
Expand Down Expand Up @@ -115,7 +120,7 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of torch.Tensor and any other info
"""
observations, reward, terminated, self._info = self._env.step(actions)
observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions))
self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"]))
truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated)
return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info
Expand Down

0 comments on commit b5e27b2

Please sign in to comment.