Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 21, 2024
1 parent 27720b7 commit 041d678
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,8 @@ def __repr__(self):
f"strict_length={self.strict_length})"
)

@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
# @staticmethod
def _find_start_stop_traj(self, *, trajectory=None, end=None, at_capacity: bool):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
Expand Down Expand Up @@ -835,6 +835,9 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
)
return self._find_start_stop_traj_sub(length=length, end=end)

def _find_start_stop_traj_sub(self, end, length):
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
Expand Down

0 comments on commit 041d678

Please sign in to comment.