diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 273cf627521..fc27401d5e5 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -24,7 +24,7 @@ from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage -from torchrl.data.replay_buffers.utils import _is_int, unravel_index +from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index try: from torchrl._torchrl import ( @@ -726,6 +726,10 @@ class SliceSampler(Sampler): This class samples sub-trajectories with replacement. For a version without replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`. + .. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate + its execution, prefer using `end_key` over `traj_key`, and consider the following + keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`. + Keyword Args: num_slices (int): the number of slices to be sampled. The batch-size must be greater or equal to the ``num_slices`` argument. Exclusive @@ -796,6 +800,10 @@ class SliceSampler(Sampler): that at least `slice_len - i` samples will be gathered for each sampled trajectory. Using tuples allows a fine grained control over the span on the left (beginning of the stored trajectory) and on the right (end of the stored trajectory). + use_gpu (bool or torch.device): if ``True`` (or is a device is passed), an accelerator + will be used to retrieve the indices of the trajectory starts. This can significanlty + accelerate the sampling when the buffer content is large. + Defaults to ``False``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first @@ -985,6 +993,7 @@ def __init__( strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False, + use_gpu: torch.device | bool = False, ): self.num_slices = num_slices self.slice_len = slice_len @@ -995,6 +1004,14 @@ def __init__( self._fetch_traj = True self.strict_length = strict_length self._cache = {} + self.use_gpu = bool(use_gpu) + self._gpu_device = ( + None + if not self.use_gpu + else torch.device(use_gpu) + if not isinstance(use_gpu, bool) + else _auto_device() + ) if isinstance(span, (bool, int)): span = (span, span) @@ -1086,9 +1103,8 @@ def __repr__(self): f"strict_length={self.strict_length})" ) - @classmethod def _find_start_stop_traj( - cls, *, trajectory=None, end=None, at_capacity: bool, cursor=None + self, *, trajectory=None, end=None, at_capacity: bool, cursor=None ): if trajectory is not None: # slower @@ -1141,10 +1157,15 @@ def _find_start_stop_traj( raise RuntimeError( "Expected the end-of-trajectory signal to be at least 1-dimensional." ) - return cls._end_to_start_stop(length=length, end=end) - - @staticmethod - def _end_to_start_stop(end, length): + return self._end_to_start_stop(length=length, end=end) + + def _end_to_start_stop(self, end, length): + device = None + if self.use_gpu: + gpu_device = self._gpu_device + if end.device != gpu_device: + device = end.device + end = end.to(self._gpu_device) # Using transpose ensures the start and stop are sorted the same way stop_idx = end.transpose(0, -1).nonzero() stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone() @@ -1171,6 +1192,8 @@ def _end_to_start_stop(end, length): pass lengths = stop_idx[:, 0] - start_idx[:, 0] + 1 lengths[lengths <= 0] = lengths[lengths <= 0] + length + if device is not None: + return start_idx.to(device), stop_idx.to(device), lengths.to(device) return start_idx, stop_idx, lengths def _start_to_end(self, st: torch.Tensor, length: int): @@ -1547,6 +1570,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): the sampler, and continuous sampling without replacement is currently not allowed. + .. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate + its execution, prefer using `end_key` over `traj_key`, and consider the following + keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`. + Keyword Args: drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped. If ``False``, this last sample will be kept. @@ -1589,6 +1616,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): the :meth:`~sample` method will be compiled with :func:`~torch.compile`. Keyword arguments can also be passed to torch.compile with this arg. Defaults to ``False``. + use_gpu (bool or torch.device): if ``True`` (or is a device is passed), an accelerator + will be used to retrieve the indices of the trajectory starts. This can significanlty + accelerate the sampling when the buffer content is large. + Defaults to ``False``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first @@ -1693,7 +1724,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.]]) - """ def __init__( @@ -1710,6 +1740,7 @@ def __init__( strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False, + use_gpu: bool | torch.device = False, ): SliceSampler.__init__( self, @@ -1723,6 +1754,7 @@ def __init__( ends=ends, trajectories=trajectories, compile=compile, + use_gpu=use_gpu, ) SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle) diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index ef941a6ca90..1e8985537f3 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -1034,3 +1034,11 @@ def tree_iter(pytree): # noqa: F811 def tree_iter(pytree): # noqa: F811 """A version-compatible wrapper around tree_iter.""" yield from torch.utils._pytree.tree_iter(pytree) + + +def _auto_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda:0") + elif torch.mps.is_available(): + return torch.device("mps:0") + return torch.device("cpu")