Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 18, 2024
2 parents 6d9a87e + 093fa97 commit 703bcac
Showing 1 changed file with 0 additions and 8 deletions.
8 changes: 0 additions & 8 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,6 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
spec.type_check(val)

def is_in(self, value) -> bool:
raise RuntimeError
if self.dim == 0 and not hasattr(value, "unbind"):
# We don't use unbind because value could be a tuple or a nested tensor
return all(
Expand Down Expand Up @@ -1821,7 +1820,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -2272,7 +2270,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
shape = torch.broadcast_shapes(self._safe_shape, val_shape)
shape = list(shape)
Expand Down Expand Up @@ -2470,7 +2467,6 @@ def one(self, shape=None):
)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
isinstance(val, NonTensorData)
Expand Down Expand Up @@ -2663,7 +2659,6 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return val.shape == shape and val.dtype == self.dtype

Expand Down Expand Up @@ -3012,7 +3007,6 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
return torch.cat(out, -1)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
vals = self._split(val)
if vals is None:
return False
Expand Down Expand Up @@ -3358,7 +3352,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -3984,7 +3977,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val.squeeze(0) if val_is_scalar else val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is not None:
vals = val.unbind(-1)
splits = self._split_self()
Expand Down

0 comments on commit 703bcac

Please sign in to comment.