Commit

Improve docstring + remove comments
Cadene committed Feb 6, 2024
1 parent ac4fa9d commit 74a4bee
Showing 1 changed file with 52 additions and 11 deletions.
63 changes: 52 additions & 11 deletions torchrl/data/replay_buffers/samplers.py
@@ -1082,7 +1082,56 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
     """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.

-    For more info see: SliceSampler and PrioritizedSampler
+    This class samples sub-trajectories with replacement following a priority weighting presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
+    Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+    For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.
+    Args:
+        alpha (float): exponent α determines how much prioritization is used,
+            with α = 0 corresponding to the uniform case.
+        beta (float): importance sampling negative exponent.
+        eps (float, optional): delta added to the priorities to ensure that the buffer
+            does not contain null priorities. Defaults to 1e-8.
+        reduction (str, optional): the reduction method for multidimensional
+            tensordicts (ie stored trajectory). Can be one of "max", "min",
+            "median" or "mean".
+    Keyword Args:
+        num_slices (int): the number of slices to be sampled. The batch-size
+            must be greater or equal to the ``num_slices`` argument. Exclusive
+            with ``slice_len``.
+        slice_len (int): the length of the slices to be sampled. The batch-size
+            must be greater or equal to the ``slice_len`` argument and divisible
+            by it. Exclusive with ``num_slices``.
+        end_key (NestedKey, optional): the key indicating the end of a
+            trajectory (or episode). Defaults to ``("next", "done")``.
+        traj_key (NestedKey, optional): the key indicating the trajectories.
+            Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
+        ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
+            To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
+            or when this signal is readily available. Must be used with ``cache_values=True``
+            and cannot be used in conjunction with ``end_key`` or ``traj_key``.
+        trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
+            To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
+            or when this signal is readily available. Must be used with ``cache_values=True``
+            and cannot be used in conjunction with ``end_key`` or ``traj_key``.
+        cache_values (bool, optional): to be used with static datasets.
+            Will cache the start and end signal of the trajectory.
+        truncated_key (NestedKey, optional): If not ``None``, this argument
+            indicates where a truncated signal should be written in the output
+            data. This is used to indicate to value estimators where the provided
+            trajectory breaks. Defaults to ``("next", "truncated")``.
+            This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
+            instances (otherwise the truncated key is returned in the info dictionary
+            returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
+        strict_length (bool, optional): if ``False``, trajectories of length
+            shorter than `slice_len` (or `batch_size // num_slices`) will be
+            allowed to appear in the batch.
+            Be mindful that this can result in effective `batch_size` shorter
+            than the one asked for! Trajectories can be split using
+            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
     """

     def __init__(
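
Reviewer note: the `alpha`, `beta` and `eps` arguments documented in this docstring follow the prioritized replay scheme of Schaul et al. The block below is a minimal plain-Python sketch of that weighting, not TorchRL's implementation (which maintains sum and min trees). The helper names `sampling_probabilities` and `importance_weights` are hypothetical, and both the placement of `eps` before exponentiation and the normalization by the largest weight in the list are assumptions made for illustration.

```python
import random


def sampling_probabilities(priorities, alpha, eps=1e-8):
    """Return P(i) = (p_i + eps)^alpha / sum_k (p_k + eps)^alpha.

    Assumption: eps is added before exponentiation so that no item
    ends up with a null priority.
    """
    scaled = [(p + eps) ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, beta):
    """Return w_i = (N * P(i))^(-beta), divided by the largest weight."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    largest = max(raw)
    return [w / largest for w in raw]


priorities = [1.0, 2.0, 4.0, 1.0]

# alpha = 0 ignores the priorities: every index is equally likely.
uniform = sampling_probabilities(priorities, alpha=0.0)

# alpha = 1 samples proportionally to the (shifted) priorities.
skewed = sampling_probabilities(priorities, alpha=1.0)

# Frequently sampled items get smaller importance weights.
weights = importance_weights(skewed, beta=1.0)

# Draw three start indices with replacement, using a seeded generator.
rng = random.Random(0)
starts = rng.choices(range(len(priorities)), weights=skewed, k=3)
```

With `alpha=0.0` the four probabilities are all 0.25, which matches the "uniform case" described for `alpha`. With `beta=1.0` the weights fully compensate for the non-uniform sampling, so the most likely index receives the smallest weight.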
@@ -1159,19 +1208,11 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
                 stop_idx[..., None] - subtractive_idx[None, ...]
             ).view(-1)
         else:
-            raise NotImplementedError()
-            # preceding_stop_idx = torch.cat(
-            # [
-            # stop_idx
-            # - torch.arange(_seq_len, device=stop_idx.device, dtype=stop_idx.dtype)
-            # for stop_idx, _seq_len in zip(stop_idx, seq_length)
-            # ]
-            # )
+            raise NotImplementedError("seq_length as a list is not supported for now")

         # force to not sample index at the end of a trajectory.
-        # it's ok to not touch self._min_tree.
+        # no need to update self._min_tree.
         self._sum_tree[preceding_stop_idx] = 0.0
-        # self._min_tree[preceding_stop_idx] = 0.0

         starts, info = PrioritizedSampler.sample(
             self, storage=storage, batch_size=batch_size // seq_length
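
Reviewer note: this hunk zeroes the sum-tree priority of the indices just before each trajectory end, so that a sampled start index always leaves room for a full slice. The plain-Python sketch below illustrates the idea on lists. It assumes that `subtractive_idx` (defined outside the lines shown here) holds the offsets `0 .. seq_length - 2`, and that every trajectory is at least `seq_length` steps long. The helper name `preceding_stop_indices` is hypothetical.

```python
def preceding_stop_indices(stop_idx, seq_length):
    """List the indices from which a slice of seq_length steps would not fit.

    Mirrors the broadcasted subtraction stop_idx[:, None] - offsets[None, :]
    flattened row by row, under the assumption that offsets are
    0 .. seq_length - 2.
    """
    offsets = range(seq_length - 1)
    return [stop - off for stop in stop_idx for off in offsets]


# Two trajectories stored back to back: steps 0..4 and steps 5..9.
priorities = [1.0] * 10
stop_idx = [4, 9]
seq_length = 3

masked = preceding_stop_indices(stop_idx, seq_length)

# Zero the priority of each masked index so it can never be drawn as a start.
for i in masked:
    priorities[i] = 0.0

# Remaining candidates: each one has seq_length steps left in its trajectory.
valid_starts = [i for i, p in enumerate(priorities) if p > 0.0]
```

In this example index 2 stays valid because steps 2, 3 and 4 all belong to the first trajectory, while index 3 is masked because a slice starting there would spill into the second trajectory.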
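
Reviewer note: the docstring describes `num_slices` and `slice_len` as mutually exclusive, and the call above passes `batch_size // seq_length` to the parent sampler. The sketch below shows how the two arguments could map to a (number of slices, slice length) pair under the constraints stated in the docstring. The function `resolve_slices` is hypothetical, and TorchRL's own validation may differ (for instance, it may enforce stricter divisibility rules).

```python
def resolve_slices(batch_size, num_slices=None, slice_len=None):
    """Return (number of slices, length of each slice) for a batch size.

    Follows the constraints stated in the docstring only; this is an
    illustrative helper, not part of TorchRL.
    """
    if (num_slices is None) == (slice_len is None):
        raise ValueError("provide exactly one of num_slices or slice_len")
    if num_slices is not None:
        if batch_size < num_slices:
            raise ValueError("batch_size must be >= num_slices")
        # The docstring gives the slice length as batch_size // num_slices.
        return num_slices, batch_size // num_slices
    if batch_size < slice_len or batch_size % slice_len != 0:
        raise ValueError("batch_size must be >= slice_len and divisible by it")
    return batch_size // slice_len, slice_len


by_count = resolve_slices(32, num_slices=4)
by_length = resolve_slices(32, slice_len=8)
```

Both calls describe the same layout: 4 slices of 8 steps each. When `strict_length=False`, shorter trajectories may appear, so the effective batch can be smaller than `batch_size`, as the docstring warns.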
