Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 7, 2024
1 parent c203b12 commit f10c59b
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
weight = torch.pow(weight / p_min, -self._beta)
return index, {"_weight": weight}

@torch.no_grad()
def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
priority = self.default_priority

Expand All @@ -405,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
"priority should be a scalar or an iterable of the same "
"length as index"
)
# make sure everything is cast to cpu
if isinstance(index, torch.Tensor) and not index.is_cpu:
index = index.cpu()
if isinstance(priority, torch.Tensor) and not priority.is_cpu:
priority = priority.cpu()

self._sum_tree[index] = priority
self._min_tree[index] = priority
Expand Down Expand Up @@ -435,10 +441,10 @@ def update_priority(
indexed elements.
"""
priority = torch.as_tensor(
priority, dtype=torch.long, device=torch.device("cpu")
)
index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(
index, dtype=torch.long, device=torch.device("cpu")
).detach()
# we need to reshape priority if it has more than one elements or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
Expand Down

0 comments on commit f10c59b

Please sign in to comment.