From 093fa97996826615c2093c396e52f2442aa7eab9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:40:58 +0000 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- torchrl/data/tensor_specs.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3e3fb9daee7..2ef74bb4521 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1392,7 +1392,6 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None: spec.type_check(val) def is_in(self, value) -> bool: - raise RuntimeError if self.dim == 0 and not hasattr(value, "unbind"): # We don't use unbind because value could be a tuple or a nested tensor return all( @@ -1821,7 +1820,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError if self.mask is None: shape = torch.broadcast_shapes(self._safe_shape, val.shape) shape_match = val.shape == shape @@ -2272,7 +2270,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError val_shape = _remove_neg_shapes(tensordict.utils._shape(val)) shape = torch.broadcast_shapes(self._safe_shape, val_shape) shape = list(shape) @@ -2470,7 +2467,6 @@ def one(self, shape=None): ) def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError shape = torch.broadcast_shapes(self._safe_shape, val.shape) return ( isinstance(val, NonTensorData) @@ -2663,7 +2659,6 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError shape = torch.broadcast_shapes(self._safe_shape, val.shape) return val.shape == shape and val.dtype == self.dtype @@ -3012,7 +3007,6 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten return torch.cat(out, -1) def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError vals = self._split(val) if vals is None: return False @@ -3358,7 +3352,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError if self.mask is None: shape = torch.broadcast_shapes(self._safe_shape, val.shape) shape_match = val.shape == shape @@ -3984,7 +3977,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val.squeeze(0) if val_is_scalar else val def is_in(self, val: torch.Tensor) -> bool: - raise RuntimeError if self.mask is not None: vals = val.unbind(-1) splits = self._split_self()