Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 4, 2024
1 parent ef7d366 commit ed98a4f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrl/data/datasets/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,8 @@ def _read_from_splits(self, item: int | torch.Tensor):
split = (item < self.frames_per_split[1:, 1].unsqueeze(1)) & (
item >= self.frames_per_split[:-1, 1].unsqueeze(1)
)
# split_tmp, idx = split.squeeze().nonzero().unbind(-1)
split_tmp, idx = split.nonzero().unbind(-1)
split_tmp, idx = split.squeeze(0).nonzero().unbind(-1)
# split_tmp, idx = split.nonzero().unbind(-1)
split = torch.zeros_like(split_tmp)
split[idx] = split_tmp
split = self.frames_per_split[split + 1, 0]
Expand Down

0 comments on commit ed98a4f

Please sign in to comment.