From 2d92eab51a69326d7c5b1b38a2efe8fee7c7d19c Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 7 Feb 2024 14:33:51 +0100 Subject: [PATCH] [BugFix] Use traj_terminated for proper population of terminated key in SliceSampler --- torchrl/data/replay_buffers/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 5e9b6dd75be..21245e37acd 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -917,7 +917,7 @@ def _sample_slices( truncated[seq_length.cumsum(0) - 1] = 1 traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1 terminated = torch.zeros_like(truncated) - if terminated.any(): + if traj_terminated.any(): if isinstance(seq_length, int): truncated.view(num_slices, -1)[traj_terminated] = 1 else: