Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 14, 2024
1 parent 89c8d98 commit 0dc4622
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def __eq__(self, other):
)


@dataclass(repr=False)
@dataclass(repr=False, frozen=True)
class CategoricalBox(Box):
"""A box of discrete, categorical values."""

Expand Down Expand Up @@ -502,7 +502,7 @@ def from_nvec(nvec: torch.Tensor):
return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)])


@dataclass(repr=False)
@dataclass(repr=False, frozen=True)
class BinaryBox(Box):
"""A box of n binary values."""

Expand Down Expand Up @@ -3313,6 +3313,7 @@ def __init__(
)
self.update_mask(mask)
self._provisional_n = None
self._undefined_n = self.space.n < 0

def enumerate(self) -> torch.Tensor:
dtype = self.dtype
Expand Down Expand Up @@ -3379,7 +3380,7 @@ def set_provisional_n(self, n: int):
self._provisional_n = n

def rand(self, shape: torch.Size = None) -> torch.Tensor:
if self.space.n < 0:
if self._undefined_n:
if self._provisional_n is None:
raise RuntimeError(
"Cannot generate random categorical samples for undefined cardinality (n=-1). "
Expand Down

0 comments on commit 0dc4622

Please sign in to comment.