Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 545bcc0 commit 5744e91
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,17 @@ class ContinuousBox(Box):
# We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
@property
def low(self):
return self._low.to(self.device)
low = self._low
if low.device != self.device:
low = low.to(self.device)
return low

@property
def high(self):
return self._high.to(self.device)
high = self._high
if high.device != self.device:
high = high.to(self.device)
return high

def unbind(self, dim: int = 0):
return tuple(
Expand Down

0 comments on commit 5744e91

Please sign in to comment.