Skip to content

Commit

Permalink
Add funtion to sample spaces by batches
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Sep 21, 2024
1 parent 38ca1c8 commit c8fe715
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion skrl/utils/spaces/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Union
from typing import Any, Literal, Optional, Sequence, Union

import gymnasium
from gymnasium import spaces
Expand Down Expand Up @@ -113,3 +113,47 @@ def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_
return int(np.prod(space))
# gymnasium computation
return gymnasium.spaces.flatdim(space)

def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "torch"], device = None) -> Any:
"""Generates a random sample from the specified space
:param space: Gymnasium space
:param batch_size: Size of the sampled batch (default: ``1``)
:param backend: Whether backend will be used to construct the fundamental spaces (default: ``"numpy"``)
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
This parameter is used when the backend is ``"torch"``
:return: Sample of the space
"""
# fundamental spaces
# Box
if isinstance(space, spaces.Box):
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
elif backend == "torch":
return torch.tensor(np.stack([space.sample() for _ in range(batch_size)]), device=device)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# Discrete
elif isinstance(space, spaces.Discrete):
if backend == "numpy":
return np.stack([[space.sample()] for _ in range(batch_size)])
elif backend == "torch":
return torch.tensor(np.stack([[space.sample()] for _ in range(batch_size)]), device=device)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# MultiDiscrete
elif isinstance(space, spaces.MultiDiscrete):
if backend == "numpy":
return np.stack([space.sample() for _ in range(batch_size)])
elif backend == "torch":
return torch.tensor(np.stack([space.sample() for _ in range(batch_size)]), device=device)
else:
raise ValueError(f"Unsupported backend type ({backend})")
# composite spaces
# Dict
elif isinstance(space, spaces.Dict):
return {k: sample_space(s, batch_size, backend, device) for k, s in space.items()}
# Tuple
elif isinstance(space, spaces.Tuple):
return tuple([sample_space(s, batch_size, backend, device) for s in space])

0 comments on commit c8fe715

Please sign in to comment.