Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 1518b4a commit 7cba108
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,14 @@ class ContinuousBox(Box):
@property
def low(self):
low = self._low
if low.device != self.device:
if self.device is not None and low.device != self.device:
low = low.to(self.device)
return low

@property
def high(self):
high = self._high
if high.device != self.device:
if self.device is not None and high.device != self.device:
high = high.to(self.device)
return high

Expand Down Expand Up @@ -2285,7 +2285,10 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
r = torch.rand(_size([*shape, *self._safe_shape]), device=interval.device)
r = interval * r
r = self.space.low + r
r = r.to(self.dtype).to(self.device)
if r.dtype != self.dtype:
r = r.to(self.dtype)
if self.dtype is not None and r.device != self.device:
r = r.to(self.device)
return r

def _project(self, val: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -2767,8 +2770,8 @@ def __eq__(self, other):
# those specs are equivalent to a discrete spec
if isinstance(other, Bounded):
minval, maxval = _minmax_dtype(self.dtype)
minval = torch.as_tensor(minval).to(self.device, self.dtype)
maxval = torch.as_tensor(maxval).to(self.device, self.dtype)
minval = torch.as_tensor(minval, device=self.device, dtype=self.dtype)
maxval = torch.as_tensor(maxval, device=self.device, dtype=self.dtype)
return (
Bounded(
shape=self.shape,
Expand Down

0 comments on commit 7cba108

Please sign in to comment.