[Performance] Faster slice sampler (#2031)
(cherry picked from commit cd540bf)
vmoens committed Mar 25, 2024
1 parent 01301ca commit d74fc05
Showing 2 changed files with 114 additions and 37 deletions.
143 changes: 110 additions & 33 deletions torchrl/data/replay_buffers/samplers.py
@@ -22,7 +22,7 @@

from torchrl._extension import EXTENSION_WARNING

from torchrl._utils import _replace_last
from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int

@@ -54,7 +54,7 @@ def extend(self, index: torch.Tensor) -> None:

def update_priority(
self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
) -> dict:
) -> dict | None:
return

def mark_update(self, index: Union[int, torch.Tensor]) -> None:
@@ -221,7 +221,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
if storage.ndim > 1:
index = torch.unravel_index(index, storage.shape)
# we 'always' return the indices. The 'drop_last' just instructs the
# sampler to turn to 'ran_out = True` whenever the next sample
# sampler to turn to `ran_out = True` whenever the next sample
# will be too short. This will be read by the replay buffer
# as a signal for an early break of the __iter__().
return index, {}
@@ -477,7 +477,7 @@ def update_priority(
"""
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
# we need to reshape priority if it has more than one elements or if it has
# we need to reshape priority if it has more than one element or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
try:
@@ -637,7 +637,25 @@ class SliceSampler(Sampler):
if the last element of the trajectory tensor is identical to the first,
the same trajectory spans across end and beginning.
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
Will cache the start and end signal of the trajectory. This can be safely used even
if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
as this operation will erase the cache.
.. warning:: ``cache_values=True`` will not work if the sampler is used with a
storage that is extended by another buffer. For instance:
>>> buffer0 = ReplayBuffer(storage=storage,
... sampler=SliceSampler(num_slices=8, cache_values=True),
... writer=ImmutableWriter())
>>> buffer1 = ReplayBuffer(storage=storage,
... sampler=other_sampler)
>>> # Wrong! Does not erase the cache from the sampler of buffer0
>>> buffer1.extend(data)
.. warning:: ``cache_values=True`` will not work as expected if the buffer is
shared between processes and one process is responsible for writing
and one process for sampling, as erasing the cache can only be done locally.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
@@ -652,6 +670,10 @@ class SliceSampler(Sampler):
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
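A minimal, illustrative sketch of the ``compile`` option (the storage size, batch size and the ``"episode"`` entry below are assumptions for the example, not part of this change):
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> # 10 trajectories of 10 steps each, identified by an "episode" entry
>>> data = TensorDict(
...     {"obs": torch.randn(100, 3), "episode": torch.arange(100) // 10},
...     batch_size=[100])
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(100),
...     sampler=SliceSampler(num_slices=8, traj_key="episode", compile=True),
...     batch_size=64)
>>> rb.extend(data)
>>> sample = rb.sample()  # 8 slices of 8 consecutive steps each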
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
@@ -730,6 +752,7 @@ def __init__(
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
compile: bool | dict = False,
):
self.num_slices = num_slices
self.slice_len = slice_len
@@ -784,6 +807,31 @@ def __init__(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.compile = bool(compile)
if self.compile:
if isinstance(compile, dict):
kwargs = compile
else:
kwargs = {}
self._get_index = torch.compile(self._get_index, **kwargs)

def __getstate__(self):
if get_spawning_popen() is not None and self.cache_values:
logger.warning(
f"It seems you are sharing a {type(self).__name__} across processes with"
f"cache_values=True. "
f"While this isn't forbidden and could perfectly work if your dataset "
f"is unaltered on both processes, remember that calling extend/add on"
f"one process will NOT erase the cache on another process's sampler, "
f"which will cause synchronization issues."
)
state = copy(self.__dict__)
state["_cache"] = {}
return state

def extend(self, index: torch.Tensor) -> None:
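# NB: extending the storage invalidates any cached trajectory start/stop indices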
if self.cache_values:
self._cache.clear()

def __repr__(self):
return (
@@ -795,8 +843,8 @@ def __repr__(self):
f"strict_length={self.strict_length})"
)

@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
@classmethod
def _find_start_stop_traj(cls, *, trajectory=None, end=None, at_capacity: bool):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
@@ -835,6 +883,10 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
)
return cls._end_to_start_stop(length=length, end=end)

@staticmethod
def _end_to_start_stop(end, length):
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
@@ -859,30 +911,33 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
lengths[lengths < 0] = lengths[lengths < 0] + length
return start_idx, stop_idx, lengths
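# Semantics, with illustrative values: an end signal of [False, False, True, False, True]
# marks two trajectories, with starts [0, 3], stops [2, 4] and lengths [3, 2]
# (start_idx/stop_idx keep the 2d, nonzero()-like layout used throughout this class).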

def _start_to_end(self, st: torch.Tensor, length: int):
arange = torch.arange(length, device=st.device, dtype=st.dtype)
ndims = st.shape[-1] - 1 if st.ndim else 0
if ndims:
arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1)
else:
arange = arange.unsqueeze(-1)
if st.shape != arange.shape:
# we do this to make sure that we're not broadcasting the start
# wrong as a tensor with shape [N] can't be expanded to [N, 1]
# without getting an error
st = st.expand_as(arange)
return arange + st
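# Illustrative trace (hypothetical values): for st = torch.tensor([10, 2]) (time index 10
# in batch entry 2) and length=3, `arange` becomes [[0, 0], [1, 0], [2, 0]] and the result
# is [[10, 2], [11, 2], [12, 2]], i.e. three consecutive steps within the same batch entry.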

def _tensor_slices_from_startend(self, seq_length, start, storage_length):
# start is a 2d tensor resulting from nonzero()
# seq_length is a 1d tensor indicating the desired length of each sequence

def _start_to_end(st: torch.Tensor, length: int):
arange = torch.arange(length, device=st.device, dtype=st.dtype)
ndims = st.shape[-1] - 1 if st.ndim else 0
arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1)
if st.shape != arange.shape:
# we do this to make sure that we're not broadcasting the start
# wrong as a tensor with shape [N] can't be expanded to [N, 1]
# without getting an error
st = st.expand_as(arange)
return arange + st

if isinstance(seq_length, int):
result = torch.cat(
[_start_to_end(_start, length=seq_length) for _start in start]
[self._start_to_end(_start, length=seq_length) for _start in start]
)
else:
# when padding is needed
result = torch.cat(
[
_start_to_end(_start, _seq_len)
self._start_to_end(_start, _seq_len)
for _start, _seq_len in zip(start, seq_length)
]
)
@@ -945,14 +1000,16 @@ def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
f"The batch-size must be divisible by the number of slices, got "
f"batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
f"The batch-size must be divisible by the slice length, got "
f"batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
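# Worked example (illustrative numbers): batch_size=64 with num_slices=8 yields 8 slices
# of seq_length=8; batch_size=64 with slice_len=16 yields 4 slices of length 16; a
# batch_size that is not divisible (e.g. 60 with num_slices=8) raises the errors above.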
@@ -993,8 +1050,8 @@ def _sample_slices(
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
# start_idx and stop_idx are 2d tensors organized like the output of nonzero()

def get_traj_idx(lengths=lengths):
return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)
def get_traj_idx(maxval):
return torch.randint(maxval, (num_slices,), device=lengths.device)

if (lengths < seq_length).any():
if self.strict_length:
@@ -1013,7 +1070,7 @@ def get_traj_idx(lengths=lengths):
stop_idx = stop_idx[idx]

if traj_idx is None:
traj_idx = get_traj_idx(lengths=lengths_idx)
traj_idx = get_traj_idx(lengths_idx.shape[0])
else:
# Here we must filter out the indices that correspond to trajectories
# we don't want to keep. That could potentially lead to an empty sample.
@@ -1036,18 +1093,37 @@ def get_traj_idx(lengths=lengths):
lengths = lengths_idx
else:
if traj_idx is None:
traj_idx = get_traj_idx()
traj_idx = get_traj_idx(lengths.shape[0])
else:
num_slices = traj_idx.shape[0]

# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)
else:
if traj_idx is None:
traj_idx = get_traj_idx()
traj_idx = get_traj_idx(lengths.shape[0])
else:
num_slices = traj_idx.shape[0]
return self._get_index(
lengths=lengths,
start_idx=start_idx,
stop_idx=stop_idx,
num_slices=num_slices,
seq_length=seq_length,
storage_length=storage_length,
traj_idx=traj_idx,
)
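# `_get_index` below carries the heavy index arithmetic; keeping it as a separate method
# lets `__init__` wrap just this hot path with `torch.compile` when `compile=True` is
# passed, while the data-dependent length filtering above stays eager.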

def _get_index(
self,
lengths: torch.Tensor,
start_idx: torch.Tensor,
stop_idx: torch.Tensor,
seq_length: int,
num_slices: int,
storage_length: int,
traj_idx: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, dict]:
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
@@ -1130,11 +1206,6 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...

def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state


class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
"""Samples slices of data along the first dimension, given start and stop signals, without replacement.
@@ -1182,6 +1253,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
shuffle (bool, optional): if ``False``, the order of the trajectories
is not shuffled. Defaults to ``True``.
compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
@@ -1256,6 +1331,7 @@ def __init__(
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
shuffle: bool = True,
compile: bool | dict = False,
):
SliceSampler.__init__(
self,
@@ -1268,6 +1344,7 @@ def __init__(
strict_length=strict_length,
ends=ends,
trajectories=trajectories,
compile=compile,
)
SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)

@@ -1360,7 +1437,7 @@ class SamplerEnsemble(Sampler):
The indices provided in the info dictionary are placed in a :class:`~tensordict.TensorDict` with
keys ``index`` and ``buffer_ids`` that allow the upper :class:`~torchrl.data.ReplayBufferEnsemble`
and :class:`~torchrl.data.StorageEnsemble` objects to retrieve the data.
This format differs from that of other samplers, which usually return indices
This format is different from with other samplers which usually return indices
as regular tensors.
"""
8 changes: 4 additions & 4 deletions torchrl/data/replay_buffers/storages.py
@@ -576,16 +576,16 @@ def flatten(self):
def __getstate__(self):
state = copy(self.__dict__)
if get_spawning_popen() is None:
len = self._len
length = self._len
del state["_len_value"]
state["len__context"] = len
state["len__context"] = length
elif not self.initialized:
# check that the storage is initialized
raise RuntimeError(
f"Cannot share a storage of type {type(self)} between processed if "
f"Cannot share a storage of type {type(self)} between processes if "
f"it has not been initialized yet. Populate the buffer with "
f"some data in the main process before passing it to the other "
f"subprocesses (or create the buffer explicitely with a TensorStorage)."
f"subprocesses (or create the buffer explicitly with a TensorStorage)."
)
else:
# check that the content is shared, otherwise tell the user we can't help
