
Commit

[Performance] Accelerate slice sampler on GPU
ghstack-source-id: a4dc1515d8b51f5ec150b2fae4e1a84254f2af09
Pull Request resolved: #2672
vmoens committed Dec 20, 2024
1 parent 4fd54fe commit 84c3ec3
Showing 2 changed files with 48 additions and 8 deletions.
48 changes: 40 additions & 8 deletions torchrl/data/replay_buffers/samplers.py
@@ -24,7 +24,7 @@

from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int, unravel_index
from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index

try:
from torchrl._torchrl import (
@@ -726,6 +726,10 @@ class SliceSampler(Sampler):
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
.. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
its execution, prefer using `end_key` over `traj_key`, and consider the following
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
@@ -796,6 +800,10 @@ class SliceSampler(Sampler):
that at least `slice_len - i` samples will be gathered for each sampled trajectory.
Using tuples allows a fine grained control over the span on the left (beginning
of the stored trajectory) and on the right (end of the stored trajectory).
use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
will be used to retrieve the indices of the trajectory starts. This can significantly
accelerate the sampling when the buffer content is large.
Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
@@ -985,6 +993,7 @@ def __init__(
strict_length: bool = True,
compile: bool | dict = False,
span: bool | int | Tuple[bool | int, bool | int] = False,
use_gpu: torch.device | bool = False,
):
self.num_slices = num_slices
self.slice_len = slice_len
@@ -995,6 +1004,14 @@ def __init__(
self._fetch_traj = True
self.strict_length = strict_length
self._cache = {}
self.use_gpu = bool(use_gpu)
self._gpu_device = (
None
if not self.use_gpu
else torch.device(use_gpu)
if not isinstance(use_gpu, bool)
else _auto_device()
)

if isinstance(span, (bool, int)):
span = (span, span)
@@ -1086,9 +1103,8 @@ def __repr__(self):
f"strict_length={self.strict_length})"
)

@classmethod
def _find_start_stop_traj(
cls, *, trajectory=None, end=None, at_capacity: bool, cursor=None
self, *, trajectory=None, end=None, at_capacity: bool, cursor=None
):
if trajectory is not None:
# slower
@@ -1141,10 +1157,15 @@ def _find_start_stop_traj(
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
)
return cls._end_to_start_stop(length=length, end=end)

@staticmethod
def _end_to_start_stop(end, length):
return self._end_to_start_stop(length=length, end=end)

def _end_to_start_stop(self, end, length):
device = None
if self.use_gpu:
gpu_device = self._gpu_device
if end.device != gpu_device:
device = end.device
end = end.to(self._gpu_device)
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
@@ -1171,6 +1192,8 @@ def _end_to_start_stop(end, length):
pass
lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
lengths[lengths <= 0] = lengths[lengths <= 0] + length
if device is not None:
return start_idx.to(device), stop_idx.to(device), lengths.to(device)
return start_idx, stop_idx, lengths

def _start_to_end(self, st: torch.Tensor, length: int):
@@ -1547,6 +1570,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
the sampler, and continuous sampling without replacement is currently not
allowed.
.. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate
its execution, prefer using `end_key` over `traj_key`, and consider the following
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
Keyword Args:
drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
If ``False``, this last sample will be kept.
@@ -1589,6 +1616,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
will be used to retrieve the indices of the trajectory starts. This can significantly
accelerate the sampling when the buffer content is large.
Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
@@ -1693,7 +1724,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
tensor([[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.]])
"""

def __init__(
@@ -1710,6 +1740,7 @@ def __init__(
strict_length: bool = True,
shuffle: bool = True,
compile: bool | dict = False,
use_gpu: bool | torch.device = False,
):
SliceSampler.__init__(
self,
@@ -1723,6 +1754,7 @@ def __init__(
ends=ends,
trajectories=trajectories,
compile=compile,
use_gpu=use_gpu,
)
SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)

8 changes: 8 additions & 0 deletions torchrl/data/replay_buffers/utils.py
@@ -1034,3 +1034,11 @@ def tree_iter(pytree): # noqa: F811
def tree_iter(pytree): # noqa: F811
"""A version-compatible wrapper around tree_iter."""
yield from torch.utils._pytree.tree_iter(pytree)


def _auto_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda:0")
elif torch.mps.is_available():
return torch.device("mps:0")
return torch.device("cpu")
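
A hypothetical micro-benchmark (also not part of the commit) sketching how the claimed speed-up could be checked on a large buffer; the sizes, trajectory lengths and iteration counts are placeholders:

import time
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

def timed_sampling(use_gpu, size=500_000, traj_len=1_000):
    # cache_values=False forces the start/stop lookup to run on every sample call.
    sampler = SliceSampler(
        num_slices=8, end_key=("next", "done"), cache_values=False, use_gpu=use_gpu
    )
    rb = ReplayBuffer(storage=LazyTensorStorage(size), sampler=sampler)
    done = torch.zeros(size, 1, dtype=torch.bool)
    done[traj_len - 1 :: traj_len] = True  # mark trajectory boundaries
    rb.extend(TensorDict({("next", "done"): done}, batch_size=[size]))
    start = time.perf_counter()
    for _ in range(10):
        rb.sample(256)
    return time.perf_counter() - start

print("cpu:", timed_sampling(use_gpu=False))
if torch.cuda.is_available() or torch.mps.is_available():
    print("gpu:", timed_sampling(use_gpu=True))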

