Skip to content

Commit

Permalink
[Performance] Avoid cloning trajs in SliceSampler
Browse files Browse the repository at this point in the history
ghstack-source-id: 2e133fcea716b202694cfa84df3f6e4ba3507bbc
Pull Request resolved: #2671
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent 21eeca4 commit 4fd54fe
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ def _get_stop_and_length(self, storage, fallback=True):
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(
trajectory=trajectory.clone(),
trajectory=trajectory,
at_capacity=storage._is_full,
cursor=getattr(storage, "_last_cursor", None),
)
Expand Down

0 comments on commit 4fd54fe

Please sign in to comment.