diff --git a/test/test_specs.py b/test/test_specs.py index 3dedc6233a9..5334281f0ee 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1689,6 +1689,85 @@ def test_unboundeddiscrete( assert spec is not spec.clone() +class TestCardinality: + @pytest.mark.parametrize("shape1", [(5, 4)]) + def test_binary(self, shape1): + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) + assert spec.cardinality() == len(list(spec.enumerate())) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_discrete( + self, + shape1, + ): + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) + assert spec.cardinality() == len(list(spec.enumerate())) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multidiscrete( + self, + shape1, + ): + if shape1 is None: + shape1 = (3,) + else: + shape1 = (*shape1, 3) + spec = MultiCategorical( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec.cardinality() == len(spec.enumerate()) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multionehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) + assert spec.cardinality() == len(list(spec.enumerate())) + + def test_non_tensor(self): + spec = NonTensor(shape=(3, 4), device="cpu") + with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."): + spec.cardinality() + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_onehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) + assert spec.cardinality() == len(list(spec.enumerate())) + + def test_composite(self): + batch_size = (5,) + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( + nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long + ) + spec5 = MultiOneHot( + nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec = Composite( + spec2=spec2, + spec3=spec3, + spec4=spec4, + spec5=spec5, + spec6=spec6, + shape=batch_size, + ) + assert spec.cardinality() == len(spec.enumerate()) + + class TestUnbind: @pytest.mark.parametrize("shape1", [(5, 4)]) def test_binary(self, shape1): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ddf6ed41c99..c03fb40f1ac 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -41,7 +41,7 @@ unravel_key, ) from tensordict.base import NO_DEFAULT -from tensordict.utils import _getitem_batch_size, NestedKey +from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for DEVICE_TYPING = Union[torch.device, str, int] @@ -582,6 +582,16 @@ def clear_device_(self) -> T: """ return self + @abc.abstractmethod + def cardinality(self) -> int: + """The cardinality of the spec. + + This refers to the number of possible outcomes in a spec. It is assumed that the cardinality of a composite + spec is the cartesian product of all possible outcomes. + + """ + ... + def encode( self, val: np.ndarray | torch.Tensor | TensorDictBase, @@ -1515,6 +1525,9 @@ def __init__( def n(self): return self.space.n + def cardinality(self) -> int: + return self.n + def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -2107,6 +2120,9 @@ def enumerate(self) -> Any: f"enumerate is not implemented for spec of class {type(self).__name__}." ) + def cardinality(self) -> int: + return float("inf") + def __eq__(self, other): return ( type(other) == type(self) @@ -2426,8 +2442,11 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) + def cardinality(self) -> Any: + raise RuntimeError("Cannot enumerate a NonTensorSpec.") + def enumerate(self) -> Any: - raise NotImplementedError("Cannot enumerate a NonTensorSpec.") + raise RuntimeError("Cannot enumerate a NonTensorSpec.") def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: if isinstance(dest, torch.dtype): @@ -2466,10 +2485,10 @@ def one(self, shape=None): data=None, batch_size=(*shape, *self._safe_shape), device=self.device ) - def is_in(self, val: torch.Tensor) -> bool: + def is_in(self, val: Any) -> bool: shape = torch.broadcast_shapes(self._safe_shape, val.shape) return ( - isinstance(val, NonTensorData) + is_non_tensor(val) and val.shape == shape # We relax constrains on device as they're hard to enforce for non-tensor # tensordicts and pointless @@ -2832,6 +2851,9 @@ def __init__( ) self.update_mask(mask) + def cardinality(self) -> int: + return torch.as_tensor(self.nvec).prod() + def enumerate(self) -> torch.Tensor: nvec = self.nvec enum_disc = self.to_categorical_spec().enumerate() @@ -3220,13 +3242,20 @@ class Categorical(TensorSpec): The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is desired for the training dimension, one should specify it explicitly. + Attributes: + n (int): The number of possible outcomes. + shape (torch.Size): The shape of the variable. + device (torch.device): The device of the tensors. + dtype (torch.dtype): The dtype of the tensors. + Args: - n (int): number of possible outcomes. + n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined, + and `set_provisional_n` must be called before sampling from this spec. shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])". - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. - mask (torch.Tensor or None): mask some of the possible outcomes when a - sample is taken. See :meth:`~.update_mask` for more information. + device (str, int or torch.device, optional): the device of the tensors. + dtype (str or torch.dtype, optional): the dtype of the tensors. + mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken. + See :meth:`~.update_mask` for more information. Examples: >>> categ = Categorical(3) @@ -3249,6 +3278,13 @@ class Categorical(TensorSpec): domain=discrete) >>> categ.rand() tensor([1]) + >>> categ = Categorical(-1) + >>> categ.set_provisional_n(5) + >>> categ.rand() + tensor(3) + + .. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n` + will raise a ``RuntimeError``. """ @@ -3276,16 +3312,31 @@ def __init__( shape=shape, space=space, device=device, dtype=dtype, domain="discrete" ) self.update_mask(mask) + self._provisional_n = None def enumerate(self) -> torch.Tensor: - arange = torch.arange(self.n, dtype=self.dtype, device=self.device) + dtype = self.dtype + if dtype is torch.bool: + dtype = torch.uint8 + arange = torch.arange(self.n, dtype=dtype, device=self.device) if self.ndim: arange = arange.view(-1, *(1,) * self.ndim) return arange.expand(self.n, *self.shape) @property def n(self): - return self.space.n + n = self.space.n + if n == -1: + n = self._provisional_n + if n is None: + raise RuntimeError( + f"Undefined cardinality for {type(self)}. Please call " + f"spec.set_provisional_n(int)." + ) + return n + + def cardinality(self) -> int: + return self.n def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -3316,13 +3367,33 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask + def set_provisional_n(self, n: int): + """Set the cardinality of the Categorical spec temporarily. + + This method is required to be called before sampling from the spec when n is -1. + + Args: + n (int): The cardinality of the Categorical spec. + + """ + self._provisional_n = n + def rand(self, shape: torch.Size = None) -> torch.Tensor: + if self.space.n < 0: + if self._provisional_n is None: + raise RuntimeError( + "Cannot generate random categorical samples for undefined cardinality (n=-1). " + "To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()." + ) + n = self._provisional_n + else: + n = self.space.n if shape is None: shape = _size([]) if self.mask is None: return torch.randint( 0, - self.space.n, + n, _size([*shape, *self.shape]), device=self.device, dtype=self.dtype, @@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: else: mask_flat = mask shape_out = mask.shape[:-1] + # Check that the mask has the right size + if mask_flat.shape[-1] != n: + raise ValueError( + "The last dimension of the mask must match the number of action allowed by the " + f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}." + ) out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) return out @@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool: dtype_match = val.dtype == self.dtype if not dtype_match: return False + if self.space.n == -1: + return True return (0 <= val).all() and (val < self.space.n).all() shape = self.mask.shape shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) @@ -3607,7 +3686,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, dtype: Union[str, torch.dtype] = torch.int8, ): - if n is None and not shape: + if n is None and shape is None: raise TypeError("Must provide either n or shape.") if n is None: n = shape[-1] @@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor: arange = arange.expand(arange.shape[0], *self.shape) return arange + def cardinality(self) -> int: + return self.nvec._base.prod() + def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -4373,7 +4455,7 @@ def set(self, name, spec): shape = spec.shape if shape[: self.ndim] != self.shape: if ( - isinstance(spec, Composite) + isinstance(spec, (Composite, NonTensor)) and spec.ndim < self.ndim and self.shape[: spec.ndim] == spec.shape ): @@ -4382,7 +4464,7 @@ def set(self, name, spec): spec.shape = self.shape else: raise ValueError( - "The shape of the spec and the Composite mismatch: the first " + f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first " f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " f"Composite.shape={self.shape}." ) @@ -4798,6 +4880,18 @@ def clone(self) -> Composite: shape=self.shape, ) + def cardinality(self) -> int: + n = None + for spec in self.values(): + if spec is None: + continue + if n is None: + n = 1 + n = n * spec.cardinality() + if n is None: + n = 0 + return n + def enumerate(self) -> TensorDictBase: # We are going to use meshgrid to create samples of all the subspecs in here # but first let's get rid of the batch size, we'll put it back later diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 78f89cc8a38..3b55fd227a7 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -561,6 +561,25 @@ def check_env_specs(self, *args, **kwargs): check_env_specs.__doc__ = check_env_specs_func.__doc__ + def cardinality(self, tensordict: TensorDictBase | None = None) -> int: + """The cardinality of the action space. + + By default, this is just a wrapper around :meth:`env.action_space.cardinality <~torchrl.data.TensorSpec.cardinality>`. + + This class is useful when the action spec is variable: + + - The number of actions can be undefined, e.g., ``Categorical(n=-1)``; + - The action cardinality may depend on the action mask; + - The shape can be dynamic, as in ``Unbound(shape=(-1))``. + + In these cases, the :meth:`~.cardinality` should be overwritten, + + Args: + tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality. + + """ + return self.full_action_spec.cardinality() + @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): # inplace update will write tensors in-place on the provided tensordict.