diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7dfe32b93be..4563fd3ca21 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -455,7 +455,7 @@ def __eq__(self, other): ) -@dataclass(repr=False) +@dataclass(repr=False, frozen=True) class CategoricalBox(Box): """A box of discrete, categorical values.""" @@ -502,7 +502,7 @@ def from_nvec(nvec: torch.Tensor): return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)]) -@dataclass(repr=False) +@dataclass(repr=False, frozen=True) class BinaryBox(Box): """A box of n binary values.""" @@ -3313,6 +3313,7 @@ def __init__( ) self.update_mask(mask) self._provisional_n = None + self._undefined_n = self.space.n < 0 def enumerate(self) -> torch.Tensor: dtype = self.dtype @@ -3379,7 +3380,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self.space.n < 0: + if self._undefined_n: if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). "