Skip to content

Commit

Permalink
Fix Isaac Lab multi-agent environment wrapper in JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 27, 2024
1 parent 6618f77 commit c2c34b9
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions skrl/envs/wrappers/jax/isaaclab_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, env: Any) -> None:
"""
super().__init__(env)

self._env_device = torch.device(self._unwrapped.device)
self._reset_once = True
self._observations = None
self._info = {}
Expand Down Expand Up @@ -89,7 +90,7 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of np.ndarray or jax.Array and any other info
"""
actions = _jax2torch(actions, self._env.device, self._jax)
actions = _jax2torch(actions, self._env_device, self._jax)

with torch.no_grad():
self._observations, reward, terminated, truncated, self._info = self._env.step(actions)
Expand Down Expand Up @@ -190,6 +191,7 @@ def __init__(self, env: Any) -> None:
"""
super().__init__(env)

self._env_device = torch.device(self._unwrapped.device)
self._reset_once = True
self._observations = None
self._info = {}
Expand All @@ -205,11 +207,16 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of dictionaries of np.ndarray or jax.Array and any other info
"""
self._observations, rewards, terminated, truncated, self._info = self._env.step(actions)
actions = {uid: _jax2torch(value, self._env_device, self._jax) for uid, value in actions.items()}

with torch.no_grad():
observations, rewards, terminated, truncated, self._info = self._env.step(actions)

self._observations = {uid: _torch2jax(value, self._jax) for uid, value in observations.items()}
return self._observations, \
{k: v.view(-1, 1) for k, v in rewards.items()}, \
{k: v.view(-1, 1) for k, v in terminated.items()}, \
{k: v.view(-1, 1) for k, v in truncated.items()}, \
{uid: _torch2jax(value.view(-1, 1), self._jax) for uid, value in rewards.items()}, \
{uid: _torch2jax(value.to(dtype=torch.int8).view(-1, 1), self._jax) for uid, value in terminated.items()}, \
{uid: _torch2jax(value.to(dtype=torch.int8).view(-1, 1), self._jax) for uid, value in truncated.items()}, \
self._info

def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Any]]:
Expand All @@ -219,17 +226,19 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str
:rtype: tuple of dictionaries of np.ndarray or jax.Array and any other info
"""
if self._reset_once:
self._observations, self._info = self._env.reset()
observations, self._info = self._env.reset()
self._observations = {uid: _torch2jax(value, self._jax) for uid, value in observations.items()}
self._reset_once = False
return self._observations, self._info

def state(self) -> Union[np.ndarray, jax.Array]:
def state(self) -> Union[np.ndarray, jax.Array, None]:
"""Get the environment state
:return: State
:rtype: np.ndarray or jax.Array
:rtype: np.ndarray, jax.Array or None
"""
return _torch2jax(self._env.state(), self._jax)
state = self._env.state()
return None if state is None else _torch2jax(state, self._jax)

def render(self, *args, **kwargs) -> None:
"""Render the environment
Expand Down

0 comments on commit c2c34b9

Please sign in to comment.