Skip to content

Commit

Permalink
Update PettingZoo wrapper to use space utils in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 13, 2024
1 parent af41e69 commit 7364b86
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 50 deletions.
58 changes: 10 additions & 48 deletions skrl/envs/wrappers/jax/pettingzoo_envs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Any, Mapping, Tuple, Union

import collections
import gymnasium

import jax
import numpy as np

from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper
from skrl.utils.spaces.jax import (
flatten_tensorized_space,
tensorize_space,
unflatten_tensorized_space,
untensorize_space
)


class PettingZooWrapper(MultiAgentEnvWrapper):
Expand All @@ -18,49 +23,6 @@ def __init__(self, env: Any) -> None:
"""
super().__init__(env)

def _observation_to_tensor(self, observation: Any, space: gymnasium.Space) -> np.ndarray:
    """Convert the Gymnasium observation to a flat tensor

    :param observation: The Gymnasium observation to convert to a tensor
    :type observation: Any supported Gymnasium observation space
    :param space: The space the observation belongs to (used to pick the conversion
                  when the observation itself is neither an ``int`` nor a NumPy array)
    :type space: gymnasium.Space

    :raises ValueError: If the observation space type is not supported

    :return: The observation as a flat tensor reshaped to ``(self.num_envs, -1)``
    :rtype: np.ndarray
    """
    # NOTE: ``ndarray.view`` reinterprets the dtype/type of an array, it does not
    # change its shape (that is the torch idiom), so ``reshape`` is used throughout
    if isinstance(observation, int):
        return np.array(observation, dtype=np.int32).reshape(self.num_envs, -1)
    elif isinstance(observation, np.ndarray):
        return observation.reshape(self.num_envs, -1).astype(np.float32)
    elif isinstance(space, gymnasium.spaces.Discrete):
        return np.array(observation, dtype=np.float32).reshape(self.num_envs, -1)
    elif isinstance(space, gymnasium.spaces.Box):
        return observation.reshape(self.num_envs, -1).astype(np.float32)
    elif isinstance(space, gymnasium.spaces.Dict):
        # flatten each entry recursively (keys in sorted order so the layout is
        # deterministic) and join them along the feature axis
        parts = [self._observation_to_tensor(observation[k], space[k]) for k in sorted(space.keys())]
        return np.concatenate(parts, axis=-1).reshape(self.num_envs, -1)
    else:
        raise ValueError(f"Observation space type {type(space)} not supported. Please report this issue")

def _tensor_to_action(self, actions: np.ndarray, space: gymnasium.Space) -> Any:
    """Translate an action array into the format expected by Gymnasium

    :param actions: The actions to perform
    :type actions: np.ndarray
    :param space: The action space the actions belong to
    :type space: gymnasium.Space

    :raise ValueError: If the action space type is not supported

    :return: The action in the Gymnasium format
    :rtype: Any supported Gymnasium action space
    """
    # continuous actions: cast to the space dtype and restore the space shape
    if isinstance(space, gymnasium.spaces.Box):
        typed_actions = actions.astype(space.dtype)
        return typed_actions.reshape(space.shape)
    # discrete actions: unwrap the single element into a Python scalar
    if isinstance(space, gymnasium.spaces.Discrete):
        return actions.item()
    raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")

def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \
Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]],
Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]],
Expand All @@ -75,11 +37,11 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \
"""
if self._jax:
actions = jax.device_get(actions)
actions = {uid: self._tensor_to_action(action, self.action_space(uid)) for uid, action in actions.items()}
actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()}
observations, rewards, terminated, truncated, infos = self._env.step(actions)

# convert response to numpy or jax
observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()}
observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()}
rewards = {uid: np.array(value, dtype=np.float32).reshape(self.num_envs, -1) for uid, value in rewards.items()}
terminated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in terminated.items()}
truncated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in truncated.items()}
Expand All @@ -96,7 +58,7 @@ def state(self) -> Union[np.ndarray, jax.Array]:
:return: State
:rtype: np.ndarray or jax.Array
"""
state = self._observation_to_tensor(self._env.state(), next(iter(self.state_spaces.values())))
state = flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), self.device, False), False)
if self._jax:
state = jax.device_put(state, device=self.device)
return state
Expand All @@ -115,7 +77,7 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str
observations, infos = outputs

# convert response to numpy or jax
observations = {uid: self._observation_to_tensor(value, self.observation_space(uid)) for uid, value in observations.items()}
observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()}
if self._jax:
observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()}
return observations, infos
Expand Down
2 changes: 0 additions & 2 deletions skrl/envs/wrappers/torch/pettingzoo_envs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Any, Mapping, Tuple

import collections
import gymnasium

import numpy as np
import torch

from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper
Expand Down

0 comments on commit 7364b86

Please sign in to comment.