diff --git a/skrl/envs/wrappers/jax/brax_envs.py b/skrl/envs/wrappers/jax/brax_envs.py index 1e052764..8fa10a29 100644 --- a/skrl/envs/wrappers/jax/brax_envs.py +++ b/skrl/envs/wrappers/jax/brax_envs.py @@ -8,6 +8,12 @@ from skrl import logger from skrl.envs.wrappers.jax.base import Wrapper +from skrl.utils.spaces.jax import ( + convert_gym_space, + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space +) class BraxWrapper(Wrapper): @@ -28,15 +34,13 @@ def __init__(self, env: Any) -> None: def observation_space(self) -> gymnasium.Space: """Observation space """ - limit = np.inf * np.ones(self._unwrapped.observation_space.shape[1:], dtype='float32') - return gymnasium.spaces.Box(-limit, limit, dtype='float32') + return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) @property def action_space(self) -> gymnasium.Space: """Action space """ - limit = np.inf * np.ones(self._unwrapped.action_space.shape[1:], dtype='float32') - return gymnasium.spaces.Box(-limit, limit, dtype='float32') + return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) def step(self, actions: Union[np.ndarray, jax.Array]) -> \ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], @@ -49,7 +53,8 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ :return: Observation, reward, terminated, truncated, info :rtype: tuple of np.ndarray or jax.Array and any other info """ - observation, reward, terminated, info = self._env.step(actions) + observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) truncated = jnp.zeros_like(terminated) if not self._jax: observation = np.asarray(jax.device_get(observation)) @@ -65,6 +70,7 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: :rtype: np.ndarray or jax.Array and any other info """ observation = self._env.reset() + observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) if not self._jax: observation = np.asarray(jax.device_get(observation)) return observation, {}