diff --git a/skrl/envs/wrappers/jax/pettingzoo_envs.py b/skrl/envs/wrappers/jax/pettingzoo_envs.py index 1e524c9a..180e0209 100644 --- a/skrl/envs/wrappers/jax/pettingzoo_envs.py +++ b/skrl/envs/wrappers/jax/pettingzoo_envs.py @@ -1,12 +1,17 @@ from typing import Any, Mapping, Tuple, Union import collections -import gymnasium import jax import numpy as np from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper +from skrl.utils.spaces.jax import ( + flatten_tensorized_space, + tensorize_space, + unflatten_tensorized_space, + untensorize_space +) class PettingZooWrapper(MultiAgentEnvWrapper): @@ -18,49 +23,6 @@ def __init__(self, env: Any) -> None: """ super().__init__(env) - def _observation_to_tensor(self, observation: Any, space: gymnasium.Space) -> np.ndarray: - """Convert the Gymnasium observation to a flat tensor - - :param observation: The Gymnasium observation to convert to a tensor - :type observation: Any supported Gymnasium observation space - - :raises: ValueError if the observation space type is not supported - - :return: The observation as a flat tensor - :rtype: np.ndarray - """ - if isinstance(observation, int): - return np.array(observation, dtype=np.int32).view(self.num_envs, -1) - elif isinstance(observation, np.ndarray): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gymnasium.spaces.Discrete): - return np.array(observation, dtype=np.float32).reshape(self.num_envs, -1) - elif isinstance(space, gymnasium.spaces.Box): - return observation.reshape(self.num_envs, -1).astype(np.float32) - elif isinstance(space, gymnasium.spaces.Dict): - tmp = np.concatenate([self._observation_to_tensor(observation[k], space[k]) \ - for k in sorted(space.keys())], axis=-1).view(self.num_envs, -1) - return tmp - else: - raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue") - - def _tensor_to_action(self, actions: np.ndarray, space: gymnasium.Space) -> Any: - """Convert the action to the Gymnasium expected format - - :param actions: The actions to perform - :type actions: np.ndarray - - :raise ValueError: If the action space type is not supported - - :return: The action in the Gymnasium format - :rtype: Any supported Gymnasium action space - """ - if isinstance(space, gymnasium.spaces.Discrete): - return actions.item() - elif isinstance(space, gymnasium.spaces.Box): - return actions.astype(space.dtype).reshape(space.shape) - raise ValueError(f"Action space type {type(space)} not supported. Please report this issue") - def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], @@ -75,11 +37,11 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ """ if self._jax: actions = jax.device_get(actions) - actions = {uid: self._tensor_to_action(action, self.action_space(uid)) for uid, action in actions.items()} + actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()} observations, rewards, terminated, truncated, infos = self._env.step(actions) # convert response to numpy or jax - observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()} + observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()} rewards = {uid: np.array(value, dtype=np.float32).reshape(self.num_envs, -1) for uid, value in rewards.items()} terminated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in terminated.items()} truncated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in truncated.items()} @@ -96,7 +58,7 @@ def state(self) -> Union[np.ndarray, jax.Array]: :return: State :rtype: np.ndarray or jax.Array """ - state = self._observation_to_tensor(self._env.state(), next(iter(self.state_spaces.values()))) + state = flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), self.device, False), False) if self._jax: state = jax.device_put(state, device=self.device) return state @@ -115,7 +77,7 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str observations, infos = outputs # convert response to numpy or jax - observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()} + observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()} if self._jax: observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()} return observations, infos diff --git a/skrl/envs/wrappers/torch/pettingzoo_envs.py b/skrl/envs/wrappers/torch/pettingzoo_envs.py index 08348f15..7b55785c 100644 --- a/skrl/envs/wrappers/torch/pettingzoo_envs.py +++ b/skrl/envs/wrappers/torch/pettingzoo_envs.py @@ -1,9 +1,7 @@ from typing import Any, Mapping, Tuple import collections -import gymnasium -import numpy as np import torch from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper