Skip to content

Commit

Permalink
Merge branch 'sampler-doc' of github.com:pytorch/rl into sampler-doc
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 7, 2024
2 parents a8a6537 + be46f16 commit c203b12
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,10 @@ def update_priority(
indexed elements.
"""
priority = torch.as_tensor(priority, dtype=torch.long)
index = torch.as_tensor(index, dtype=torch.long)
priority = torch.as_tensor(
priority, dtype=torch.long, device=torch.device("cpu")
)
index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
# we need to reshape priority if it has more than one elements or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
Expand Down

0 comments on commit c203b12

Please sign in to comment.