diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index c9e012ea069..d0471d29f1e 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1082,7 +1082,56 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling. - For more info see: SliceSampler and PrioritizedSampler + This class samples sub-trajectories with replacement following the priority weighting presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. + Prioritized experience replay." + (https://arxiv.org/abs/1511.05952) + + For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`. + + Args: + alpha (float): the exponent α, which determines how much prioritization is used, + with α = 0 corresponding to the uniform case. + beta (float): importance sampling negative exponent. + eps (float, optional): delta added to the priorities to ensure that the buffer + does not contain null priorities. Defaults to 1e-8. + reduction (str, optional): the reduction method for multidimensional + tensordicts (i.e., stored trajectories). Can be one of "max", "min", + "median" or "mean". + + Keyword Args: + num_slices (int): the number of slices to be sampled. The batch-size + must be greater than or equal to the ``num_slices`` argument. Exclusive + with ``slice_len``. + slice_len (int): the length of the slices to be sampled. The batch-size + must be greater than or equal to the ``slice_len`` argument and divisible + by it. Exclusive with ``num_slices``. + end_key (NestedKey, optional): the key indicating the end of a + trajectory (or episode). Defaults to ``("next", "done")``. + traj_key (NestedKey, optional): the key indicating the trajectories. 
+ Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end-of-run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + cache_values (bool, optional): to be used with static datasets. + Will cache the start and end signal of the trajectory. + truncated_key (NestedKey, optional): If not ``None``, this argument + indicates where a truncated signal should be written in the output + data. This is used to indicate to value estimators where the provided + trajectory breaks. Defaults to ``("next", "truncated")``. + This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` + instances (otherwise the truncated key is returned in the info dictionary + returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method). + strict_length (bool, optional): if ``False``, trajectories of length + shorter than ``slice_len`` (or ``batch_size // num_slices``) will be + allowed to appear in the batch. + Be mindful that this can result in an effective ``batch_size`` smaller + than the one asked for! Trajectories can be split using + :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. """ def __init__( @@ -1159,19 +1208,11 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] stop_idx[..., None] - subtractive_idx[None, ...] 
).view(-1) else: - raise NotImplementedError() - # preceding_stop_idx = torch.cat( - # [ - # stop_idx - # - torch.arange(_seq_len, device=stop_idx.device, dtype=stop_idx.dtype) - # for stop_idx, _seq_len in zip(stop_idx, seq_length) - # ] - # ) + raise NotImplementedError("seq_length as a list is not supported for now") # force to not sample index at the end of a trajectory. - # it's ok to not touch self._min_tree. + # no need to update self._min_tree. self._sum_tree[preceding_stop_idx] = 0.0 - # self._min_tree[preceding_stop_idx] = 0.0 starts, info = PrioritizedSampler.sample( self, storage=storage, batch_size=batch_size // seq_length