Skip to content

Commit

Permalink
Update Isaac Gym preview wrapper in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 13, 2024
1 parent a6c1601 commit f04af60
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions skrl/envs/wrappers/jax/isaacgym_envs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Tuple, Union

import gym

import jax
import jax.dlpack as jax_dlpack
import numpy as np
Expand Down Expand Up @@ -42,7 +44,8 @@ def __init__(self, env: Any) -> None:
super().__init__(env)

self._reset_once = True
self._obs_buf = None
self._observations = None
self._info = {}

def step(self, actions: Union[np.ndarray, jax.Array]) -> \
Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array],
Expand All @@ -58,16 +61,16 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \
actions = _jax2torch(actions, self._env.device, self._jax)

with torch.no_grad():
self._obs_buf, reward, terminated, info = self._env.step(actions)
self._observations, reward, terminated, self._info = self._env.step(actions)

terminated = terminated.to(dtype=torch.int8)
truncated = info["time_outs"].to(dtype=torch.int8) if "time_outs" in info else torch.zeros_like(terminated)
truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated)

return _torch2jax(self._obs_buf, self._jax), \
return _torch2jax(self._observations, self._jax), \
_torch2jax(reward.view(-1, 1), self._jax), \
_torch2jax(terminated.view(-1, 1), self._jax), \
_torch2jax(truncated.view(-1, 1), self._jax), \
info
self._info

def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
"""Reset the environment
Expand All @@ -76,9 +79,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
:rtype: np.ndarray or jax.Array and any other info
"""
if self._reset_once:
self._obs_buf = self._env.reset()
self._observations = self._env.reset()
self._reset_once = False
return _torch2jax(self._obs_buf, self._jax), {}
return _torch2jax(self._observations, self._jax), self._info

def render(self, *args, **kwargs) -> None:
"""Render the environment
Expand All @@ -101,7 +104,19 @@ def __init__(self, env: Any) -> None:
super().__init__(env)

self._reset_once = True
self._obs_dict = None
self._observations = None
self._info = {}

@property
def state_space(self) -> Union[gym.Space, None]:
"""State space
"""
try:
if self.num_states:
return self._unwrapped.state_space
except:
pass
return None

def step(self, actions: Union[np.ndarray, jax.Array]) ->\
Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array],
Expand All @@ -117,16 +132,16 @@ def step(self, actions: Union[np.ndarray, jax.Array]) ->\
actions = _jax2torch(actions, self._env.device, self._jax)

with torch.no_grad():
self._obs_dict, reward, terminated, info = self._env.step(actions)
self._observations, reward, terminated, self._info = self._env.step(actions)

terminated = terminated.to(dtype=torch.int8)
truncated = info["time_outs"].to(dtype=torch.int8) if "time_outs" in info else torch.zeros_like(terminated)
truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated)

return _torch2jax(self._obs_dict["obs"], self._jax), \
return _torch2jax(self._observations["obs"], self._jax), \
_torch2jax(reward.view(-1, 1), self._jax), \
_torch2jax(terminated.view(-1, 1), self._jax), \
_torch2jax(truncated.view(-1, 1), self._jax), \
info
self._info

def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
"""Reset the environment
Expand All @@ -135,9 +150,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
:rtype: np.ndarray or jax.Array and any other info
"""
if self._reset_once:
self._obs_dict = self._env.reset()
self._observations = self._env.reset()
self._reset_once = False
return _torch2jax(self._obs_dict["obs"], self._jax), {}
return _torch2jax(self._observations["obs"], self._jax), self._info

def render(self, *args, **kwargs) -> None:
"""Render the environment
Expand Down

0 comments on commit f04af60

Please sign in to comment.