diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 15e46ae1038..0352b803b66 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -790,7 +790,7 @@ def _get_stop_and_length(self, storage, fallback=True): raise RuntimeError( "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." ) - vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] + vals = self._find_start_stop_traj(end=done.squeeze()[: len(storage)]) if self.cache_values: self._cache["stop-and-length"] = vals return vals