Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 19, 2024
1 parent 5d8db9b commit 234e2e8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
22 changes: 19 additions & 3 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int, unravel_index
from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index

try:
from torchrl._torchrl import (
Expand Down Expand Up @@ -985,6 +985,7 @@ def __init__(
strict_length: bool = True,
compile: bool | dict = False,
span: bool | int | Tuple[bool | int, bool | int] = False,
use_gpu: torch.device | bool = False,
):
self.num_slices = num_slices
self.slice_len = slice_len
Expand All @@ -995,6 +996,14 @@ def __init__(
self._fetch_traj = True
self.strict_length = strict_length
self._cache = {}
self.use_gpu = bool(use_gpu)
self._gpu_device = (
None
if not self.use_gpu
else torch.device(use_gpu)
if not isinstance(use_gpu, bool)
else _auto_device()
)

if isinstance(span, (bool, int)):
span = (span, span)
Expand Down Expand Up @@ -1143,8 +1152,13 @@ def _find_start_stop_traj(
)
return cls._end_to_start_stop(length=length, end=end)

@staticmethod
def _end_to_start_stop(end, length):
def _end_to_start_stop(self, end, length):
device = None
if self.use_gpu:
gpu_device = self._gpu_device
if end.device != gpu_device:
device = end.device
end = end.to(self._gpu_device)
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
Expand All @@ -1171,6 +1185,8 @@ def _end_to_start_stop(end, length):
pass
lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
lengths[lengths <= 0] = lengths[lengths <= 0] + length
if device is not None:
return start_idx.to(device), stop_idx.to(device), lengths.to(device)
return start_idx, stop_idx, lengths

def _start_to_end(self, st: torch.Tensor, length: int):
Expand Down
8 changes: 8 additions & 0 deletions torchrl/data/replay_buffers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,3 +1034,11 @@ def tree_iter(pytree): # noqa: F811
def tree_iter(pytree): # noqa: F811
"""A version-compatible wrapper around tree_iter."""
yield from torch.utils._pytree.tree_iter(pytree)


def _auto_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda:0")
elif torch.mps.is_available():
return torch.device("mps:0")
return torch.device("cpu")

0 comments on commit 234e2e8

Please sign in to comment.