Skip to content

Commit

Permalink
Update Brax wrapper to use space utils in jax
Browse files: browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 13, 2024
1 parent 7364b86 commit e59b7c5
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions skrl/envs/wrappers/jax/brax_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

from skrl import logger
from skrl.envs.wrappers.jax.base import Wrapper
+ from skrl.utils.spaces.jax import (
+ convert_gym_space,
+ flatten_tensorized_space,
+ tensorize_space,
+ unflatten_tensorized_space
+ )


class BraxWrapper(Wrapper):
Expand All @@ -28,15 +34,13 @@ def __init__(self, env: Any) -> None:
def observation_space(self) -> gymnasium.Space:
"""Observation space
"""
- limit = np.inf * np.ones(self._unwrapped.observation_space.shape[1:], dtype='float32')
- return gymnasium.spaces.Box(-limit, limit, dtype='float32')
+ return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True)

@property
def action_space(self) -> gymnasium.Space:
"""Action space
"""
- limit = np.inf * np.ones(self._unwrapped.action_space.shape[1:], dtype='float32')
- return gymnasium.spaces.Box(-limit, limit, dtype='float32')
+ return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True)

def step(self, actions: Union[np.ndarray, jax.Array]) -> \
Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array],
Expand All @@ -49,7 +53,8 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of np.ndarray or jax.Array and any other info
"""
- observation, reward, terminated, info = self._env.step(actions)
+ observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions))
+ observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device))
truncated = jnp.zeros_like(terminated)
if not self._jax:
observation = np.asarray(jax.device_get(observation))
Expand All @@ -65,6 +70,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
:rtype: np.ndarray or jax.Array and any other info
"""
observation = self._env.reset()
+ observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device))
if not self._jax:
observation = np.asarray(jax.device_get(observation))
return observation, {}
Expand Down

0 comments on commit e59b7c5

Please sign in to comment.