diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 273cf627521..23a84a37961 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -24,7 +24,7 @@ from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage -from torchrl.data.replay_buffers.utils import _is_int, unravel_index +from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index try: from torchrl._torchrl import ( @@ -985,6 +985,7 @@ def __init__( strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False, + use_gpu: torch.device | bool = False, ): self.num_slices = num_slices self.slice_len = slice_len @@ -995,6 +996,14 @@ def __init__( self._fetch_traj = True self.strict_length = strict_length self._cache = {} + self.use_gpu = bool(use_gpu) + self._gpu_device = ( + None + if not self.use_gpu + else torch.device(use_gpu) + if not isinstance(use_gpu, bool) + else _auto_device() + ) if isinstance(span, (bool, int)): span = (span, span) @@ -1143,8 +1152,13 @@ def _find_start_stop_traj( ) return cls._end_to_start_stop(length=length, end=end) - @staticmethod - def _end_to_start_stop(end, length): + def _end_to_start_stop(self, end, length): + device = None + if self.use_gpu: + gpu_device = self._gpu_device + if end.device != gpu_device: + device = end.device + end = end.to(self._gpu_device) # Using transpose ensures the start and stop are sorted the same way stop_idx = end.transpose(0, -1).nonzero() stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone() @@ -1171,6 +1185,8 @@ def _end_to_start_stop(end, length): pass lengths = stop_idx[:, 0] - start_idx[:, 0] + 1 lengths[lengths <= 0] = lengths[lengths <= 0] + length + if device is not None: + return start_idx.to(device), stop_idx.to(device), lengths.to(device) return start_idx, stop_idx, lengths def _start_to_end(self, st: torch.Tensor, length: int): diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index ef941a6ca90..1e8985537f3 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -1034,3 +1034,11 @@ def tree_iter(pytree): # noqa: F811 def tree_iter(pytree): # noqa: F811 """A version-compatible wrapper around tree_iter.""" yield from torch.utils._pytree.tree_iter(pytree) + + +def _auto_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda:0") + elif torch.mps.is_available(): + return torch.device("mps:0") + return torch.device("cpu")