Skip to content

Commit

Permalink
[BugFix] Use traj_terminated in SliceSampler (#1884)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene authored Feb 7, 2024
1 parent 144f547 commit b34e2d2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def _sample_slices(
truncated[seq_length.cumsum(0) - 1] = 1
traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1
terminated = torch.zeros_like(truncated)
if terminated.any():
if traj_terminated.any():
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[traj_terminated] = 1
else:
Expand Down

0 comments on commit b34e2d2

Please sign in to comment.