Skip to content

Commit

Permalink
Add space utility in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 5, 2024
1 parent ac92ba5 commit 05581c4
Show file tree
Hide file tree
Showing 2 changed files with 297 additions and 0 deletions.
9 changes: 9 additions & 0 deletions skrl/utils/spaces/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from skrl.utils.spaces.jax.spaces import (
compute_space_size,
convert_gym_space,
flatten_tensorized_space,
sample_space,
tensorize_space,
unflatten_tensorized_space,
untensorize_space
)
288 changes: 288 additions & 0 deletions skrl/utils/spaces/jax/spaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
from typing import Any, Literal, Optional, Sequence, Union

import gymnasium
from gymnasium import spaces

import jax
import jax.numpy as jnp
import numpy as np

from skrl import config


def convert_gym_space(space: "gym.Space", squeeze_batch_dimension: bool = False) -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.
:param space: Gym space to convert to.
:param squeeze_batch_dimension: Whether to remove fundamental spaces' first dimension.
It currently affects ``Box`` space only.
:raises ValueError: The given space is not supported.
:return: Converted space.
"""
import gym

if isinstance(space, gym.spaces.Discrete):
return spaces.Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
if squeeze_batch_dimension:
return spaces.Box(low=space.low[0], high=space.high[0], shape=space.shape[1:], dtype=space.dtype)
return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
elif isinstance(space, gym.spaces.MultiDiscrete):
return spaces.MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.Tuple):
return spaces.Tuple(spaces=tuple(map(convert_gym_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()})
raise ValueError(f"Unsupported space ({space})")

def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, jax.Device]] = None) -> Any:
"""Convert the sample/value items of a given gymnasium space to JAX array.
Fundamental spaces (:py:class:`~gymnasium.spaces.Box`, :py:class:`~gymnasium.spaces.Discrete`, and
:py:class:`~gymnasium.spaces.MultiDiscrete`) are converted to :py:class:`~jax.Array` with shape
(-1, space's shape). Composite spaces (:py:class:`~gymnasium.spaces.Dict` and :py:class:`~gymnasium.spaces.Tuple`)
are converted by recursively calling this function on their elements.
:param space: Gymnasium space.
:param x: Sample/value of the given space to tensorize to.
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
This parameter is used when the space value is not a JAX array (e.g.: numpy array, number).
:raises ValueError: The given space or the sample/value type is not supported.
:return: Sample/value space with items converted to tensors.
"""
if x is None:
return None
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
if isinstance(x, jax.Array):
return x.reshape(-1, *space.shape)
elif isinstance(x, np.ndarray):
return jax.device_put(x.reshape(-1, *space.shape), config.jax.parse_device(device))
else:
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# Discrete
elif isinstance(space, spaces.Discrete):
if isinstance(x, jax.Array):
return x.reshape(-1, 1)
elif isinstance(x, np.ndarray):
return jax.device_put(x.reshape(-1, 1), config.jax.parse_device(device))
elif isinstance(x, np.number) or type(x) in [int, float]:
return jnp.array([x], device=device, dtype=jnp.int32).reshape(-1, 1)
else:
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
if isinstance(x, jax.Array):
return x.reshape(-1, *space.shape)
elif isinstance(x, np.ndarray):
return jax.device_put(x.reshape(-1, *space.shape), config.jax.parse_device(device))
elif type(x) in [list, tuple]:
return jnp.array(x, device=device, dtype=jnp.int32).reshape(-1, *space.shape)
else:
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
return {k: tensorize_space(s, x[k], device) for k, s in space.items()}
# Tuple
elif isinstance(space, spaces.Tuple):
return tuple([tensorize_space(s, _x, device) for s, _x in zip(space, x)])
raise ValueError(f"Unsupported space ({space})")

def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool = True) -> Any:
"""Convert a tensorized space to a gymnasium space with expected sample/value item types.
:param space: Gymnasium space.
:param x: Tensorized space (Sample/value space where items are tensors).
:param squeeze_batch_dimension: Whether to remove the batch dimension. If True, only the
sample/value with a batch dimension of size 1 will be affected
:raises ValueError: The given space or the sample/value type is not supported.
:return: Sample/value space with expected item types.
"""
if x is None:
return None
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
if isinstance(x, jax.Array):
array = np.asarray(jax.device_get(x), dtype=space.dtype)
if squeeze_batch_dimension and array.shape[0] == 1:
return array.reshape(space.shape)
return array.reshape(-1, *space.shape)
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# Discrete
elif isinstance(space, spaces.Discrete):
if isinstance(x, jax.Array):
array = np.asarray(jax.device_get(x), dtype=space.dtype)
if squeeze_batch_dimension and array.shape[0] == 1:
return array.item()
return array.reshape(-1, 1)
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
if isinstance(x, jax.Array):
array = np.asarray(jax.device_get(x), dtype=space.dtype)
if squeeze_batch_dimension and array.shape[0] == 1:
return array.reshape(space.nvec.shape)
return array.reshape(-1, *space.nvec.shape)
raise ValueError(f"Unsupported type ({type(x)}) for the given space ({space})")
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
return {k: untensorize_space(s, x[k], squeeze_batch_dimension) for k, s in space.items()}
# Tuple
elif isinstance(space, spaces.Tuple):
return tuple([untensorize_space(s, _x, squeeze_batch_dimension) for s, _x in zip(space, x)])
raise ValueError(f"Unsupported space ({space})")

def flatten_tensorized_space(x: Any) -> jax.Array:
"""Flatten a tensorized space.
:param x: Tensorized space sample/value.
:raises ValueError: The given sample/value type is not supported.
:return: A tensor. The returned tensor will have shape (batch, space size).
"""
# fundamental spaces
# Box / Discrete / MultiDiscrete
if isinstance(x, jax.Array):
return x.reshape(x.shape[0], -1) if x.ndim > 1 else x.reshape(1, -1)
# composite spaces
# Dict
elif isinstance(x, dict):
return jnp.concatenate([flatten_tensorized_space(x[k])for k in sorted(x.keys())], axis=-1)
# Tuple
elif type(x) in [list, tuple]:
return jnp.concatenate([flatten_tensorized_space(_x) for _x in x], axis=-1)
raise ValueError(f"Unsupported sample/value type ({type(x)})")

def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: jax.Array) -> Any:
"""Unflatten a tensor to create a tensorized space.
:param space: Gymnasium space.
:param x: A tensor with shape (batch, space size).
:raises ValueError: The given space is not supported.
:return: Tensorized space value.
"""
if x is None:
return None
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
return x.reshape(-1, *space.shape)
# Discrete
elif isinstance(space, spaces.Discrete):
return x.reshape(-1, 1)
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
return x.reshape(-1, *space.shape)
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
start = 0
output = {}
for k in sorted(space.keys()):
end = start + compute_space_size(space[k], occupied_size=True)
output[k] = unflatten_tensorized_space(space[k], x[:, start:end])
start = end
return output
# Tuple
elif isinstance(space, spaces.Tuple):
start = 0
output = []
for s in space:
end = start + compute_space_size(s, occupied_size=True)
output.append(unflatten_tensorized_space(s, x[:, start:end]))
start = end
return output
raise ValueError(f"Unsupported space ({space})")

def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int:
"""Get the size (number of elements) of a space.
:param space: Gymnasium space.
:param occupied_size: Whether the number of elements occupied by the space is returned (default: ``False``).
It only affects :py:class:`~gymnasium.spaces.Discrete` (occupied space is 1),
and :py:class:`~gymnasium.spaces.MultiDiscrete` (occupied space is the number of discrete spaces).
:return: Size of the space (number of elements).
"""
if occupied_size:
# fundamental spaces
# Discrete
if isinstance(space, spaces.Discrete):
return 1
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
return space.nvec.shape[0]
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
return sum([compute_space_size(s, occupied_size) for s in space.values()])
# Tuple
elif isinstance(space, spaces.Tuple):
return sum([compute_space_size(s, occupied_size) for s in space])
# non-gymnasium spaces
if type(space) in [int, float]:
return space
elif type(space) in [tuple, list]:
return int(np.prod(space))
# gymnasium computation
return gymnasium.spaces.flatdim(space)

def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "jax"], device = None) -> Any:
"""Generates a random sample from the specified space.
:param space: Gymnasium space.
:param batch_size: Size of the sampled batch (default: ``1``).
:param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``).
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
This parameter is used when the backend is ``"jax"``.
:raises ValueError: The given space or backend is not supported.
:return: Sample of the space
"""
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
elif backend == "jax":
return jnp.array(np.stack([space.sample() for _ in range(batch_size)]), device=device)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# Discrete
elif isinstance(space, spaces.Discrete):
if backend == "numpy":
return np.stack([[space.sample()] for _ in range(batch_size)])
elif backend == "jax":
return jnp.array(np.stack([[space.sample()] for _ in range(batch_size)]), device=device)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
elif backend == "jax":
return jnp.array(np.stack([space.sample() for _ in range(batch_size)]), device=device)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
return {k: sample_space(s, batch_size, backend, device) for k, s in space.items()}
# Tuple
elif isinstance(space, spaces.Tuple):
return tuple([sample_space(s, batch_size, backend, device) for s in space])
raise ValueError(f"Unsupported space ({space})")

0 comments on commit 05581c4

Please sign in to comment.