Skip to content

Commit

Permalink
[Feature] spec.cardinality
Browse files Browse the repository at this point in the history
ghstack-source-id: 1160900f8a81dd51dc72436e1af69c8248bff162
Pull Request resolved: #2638
  • Loading branch information
vmoens committed Dec 12, 2024
1 parent 4bc40a8 commit dd26ae7
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 15 deletions.
79 changes: 79 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,85 @@ def test_unboundeddiscrete(
assert spec is not spec.clone()


class TestCardinality:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1):
spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
assert spec.cardinality() == len(list(spec.enumerate()))

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_discrete(
self,
shape1,
):
spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec.cardinality() == len(list(spec.enumerate()))

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multidiscrete(
self,
shape1,
):
if shape1 is None:
shape1 = (3,)
else:
shape1 = (*shape1, 3)
spec = MultiCategorical(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.cardinality() == len(spec.enumerate())

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multionehot(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
assert spec.cardinality() == len(list(spec.enumerate()))

def test_non_tensor(self):
spec = NonTensor(shape=(3, 4), device="cpu")
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
spec.cardinality()

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
assert spec.cardinality() == len(list(spec.enumerate()))

def test_composite(self):
batch_size = (5,)
spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
spec4 = MultiCategorical(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
spec5 = MultiOneHot(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
spec = Composite(
spec2=spec2,
spec3=spec3,
spec4=spec4,
spec5=spec5,
spec6=spec6,
shape=batch_size,
)
assert spec.cardinality() == len(spec.enumerate())


class TestUnbind:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1):
Expand Down
124 changes: 109 additions & 15 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
unravel_key,
)
from tensordict.base import NO_DEFAULT
from tensordict.utils import _getitem_batch_size, NestedKey
from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for

DEVICE_TYPING = Union[torch.device, str, int]
Expand Down Expand Up @@ -582,6 +582,16 @@ def clear_device_(self) -> T:
"""
return self

@abc.abstractmethod
def cardinality(self) -> int:
"""The cardinality of the spec.
This refers to the number of possible outcomes in a spec. It is assumed that the cardinality of a composite
spec is the cartesian product of all possible outcomes.
"""
...

def encode(
self,
val: np.ndarray | torch.Tensor | TensorDictBase,
Expand Down Expand Up @@ -1515,6 +1525,9 @@ def __init__(
def n(self):
return self.space.n

def cardinality(self) -> int:
return self.n

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
Expand Down Expand Up @@ -2107,6 +2120,9 @@ def enumerate(self) -> Any:
f"enumerate is not implemented for spec of class {type(self).__name__}."
)

def cardinality(self) -> int:
return float("inf")

def __eq__(self, other):
return (
type(other) == type(self)
Expand Down Expand Up @@ -2426,8 +2442,11 @@ def __init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
)

def cardinality(self) -> Any:
raise RuntimeError("Cannot enumerate a NonTensorSpec.")

def enumerate(self) -> Any:
raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
raise RuntimeError("Cannot enumerate a NonTensorSpec.")

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
if isinstance(dest, torch.dtype):
Expand Down Expand Up @@ -2466,10 +2485,10 @@ def one(self, shape=None):
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
)

def is_in(self, val: torch.Tensor) -> bool:
def is_in(self, val: Any) -> bool:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
isinstance(val, NonTensorData)
is_non_tensor(val)
and val.shape == shape
# We relax constrains on device as they're hard to enforce for non-tensor
# tensordicts and pointless
Expand Down Expand Up @@ -2832,6 +2851,9 @@ def __init__(
)
self.update_mask(mask)

def cardinality(self) -> int:
return torch.as_tensor(self.nvec).prod()

def enumerate(self) -> torch.Tensor:
nvec = self.nvec
enum_disc = self.to_categorical_spec().enumerate()
Expand Down Expand Up @@ -3220,13 +3242,20 @@ class Categorical(TensorSpec):
The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is
desired for the training dimension, one should specify it explicitly.
Attributes:
n (int): The number of possible outcomes.
shape (torch.Size): The shape of the variable.
device (torch.device): The device of the tensors.
dtype (torch.dtype): The dtype of the tensors.
Args:
n (int): number of possible outcomes.
n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined,
and `set_provisional_n` must be called before sampling from this spec.
shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])".
device (str, int or torch.device, optional): device of the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
device (str, int or torch.device, optional): the device of the tensors.
dtype (str or torch.dtype, optional): the dtype of the tensors.
mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken.
See :meth:`~.update_mask` for more information.
Examples:
>>> categ = Categorical(3)
Expand All @@ -3249,6 +3278,13 @@ class Categorical(TensorSpec):
domain=discrete)
>>> categ.rand()
tensor([1])
>>> categ = Categorical(-1)
>>> categ.set_provisional_n(5)
>>> categ.rand()
tensor(3)
.. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n`
will raise a ``RuntimeError``.
"""

Expand Down Expand Up @@ -3276,16 +3312,31 @@ def __init__(
shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
)
self.update_mask(mask)
self._provisional_n = None

def enumerate(self) -> torch.Tensor:
arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
dtype = self.dtype
if dtype is torch.bool:
dtype = torch.uint8
arange = torch.arange(self.n, dtype=dtype, device=self.device)
if self.ndim:
arange = arange.view(-1, *(1,) * self.ndim)
return arange.expand(self.n, *self.shape)

@property
def n(self):
return self.space.n
n = self.space.n
if n == -1:
n = self._provisional_n
if n is None:
raise RuntimeError(
f"Undefined cardinality for {type(self)}. Please call "
f"spec.set_provisional_n(int)."
)
return n

def cardinality(self) -> int:
return self.n

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
Expand Down Expand Up @@ -3316,13 +3367,33 @@ def update_mask(self, mask):
raise ValueError("Only boolean masks are accepted.")
self.mask = mask

def set_provisional_n(self, n: int):
"""Set the cardinality of the Categorical spec temporarily.
This method is required to be called before sampling from the spec when n is -1.
Args:
n (int): The cardinality of the Categorical spec.
"""
self._provisional_n = n

def rand(self, shape: torch.Size = None) -> torch.Tensor:
if self.space.n < 0:
if self._provisional_n is None:
raise RuntimeError(
"Cannot generate random categorical samples for undefined cardinality (n=-1). "
"To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()."
)
n = self._provisional_n
else:
n = self.space.n
if shape is None:
shape = _size([])
if self.mask is None:
return torch.randint(
0,
self.space.n,
n,
_size([*shape, *self.shape]),
device=self.device,
dtype=self.dtype,
Expand All @@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
else:
mask_flat = mask
shape_out = mask.shape[:-1]
# Check that the mask has the right size
if mask_flat.shape[-1] != n:
raise ValueError(
"The last dimension of the mask must match the number of action allowed by the "
f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}."
)
out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
return out

Expand All @@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool:
dtype_match = val.dtype == self.dtype
if not dtype_match:
return False
if self.space.n == -1:
return True
return (0 <= val).all() and (val < self.space.n).all()
shape = self.mask.shape
shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
Expand Down Expand Up @@ -3607,7 +3686,7 @@ def __init__(
device: Optional[DEVICE_TYPING] = None,
dtype: Union[str, torch.dtype] = torch.int8,
):
if n is None and not shape:
if n is None and shape is None:
raise TypeError("Must provide either n or shape.")
if n is None:
n = shape[-1]
Expand Down Expand Up @@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor:
arange = arange.expand(arange.shape[0], *self.shape)
return arange

def cardinality(self) -> int:
return self.nvec._base.prod()

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
Expand Down Expand Up @@ -4373,7 +4455,7 @@ def set(self, name, spec):
shape = spec.shape
if shape[: self.ndim] != self.shape:
if (
isinstance(spec, Composite)
isinstance(spec, (Composite, NonTensor))
and spec.ndim < self.ndim
and self.shape[: spec.ndim] == spec.shape
):
Expand All @@ -4382,7 +4464,7 @@ def set(self, name, spec):
spec.shape = self.shape
else:
raise ValueError(
"The shape of the spec and the Composite mismatch: the first "
f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
f"Composite.shape={self.shape}."
)
Expand Down Expand Up @@ -4798,6 +4880,18 @@ def clone(self) -> Composite:
shape=self.shape,
)

def cardinality(self) -> int:
n = None
for spec in self.values():
if spec is None:
continue
if n is None:
n = 1
n = n * spec.cardinality()
if n is None:
n = 0
return n

def enumerate(self) -> TensorDictBase:
# We are going to use meshgrid to create samples of all the subspecs in here
# but first let's get rid of the batch size, we'll put it back later
Expand Down
19 changes: 19 additions & 0 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,25 @@ def check_env_specs(self, *args, **kwargs):

check_env_specs.__doc__ = check_env_specs_func.__doc__

def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
"""The cardinality of the action space.
By default, this is just a wrapper around :meth:`env.action_space.cardinality <~torchrl.data.TensorSpec.cardinality>`.
This class is useful when the action spec is variable:
- The number of actions can be undefined, e.g., ``Categorical(n=-1)``;
- The action cardinality may depend on the action mask;
- The shape can be dynamic, as in ``Unbound(shape=(-1))``.
In these cases, the :meth:`~.cardinality` should be overwritten,
Args:
tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality.
"""
return self.full_action_spec.cardinality()

@classmethod
def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
# inplace update will write tensors in-place on the provided tensordict.
Expand Down

0 comments on commit dd26ae7

Please sign in to comment.