From f10c59b6e3a4b63fdda2a9a792ad399357939631 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 7 Feb 2024 09:05:59 +0000
Subject: [PATCH] amend

---
 torchrl/data/replay_buffers/samplers.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 9cd851ff52f..5e9b6dd75be 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -393,6 +393,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
         weight = torch.pow(weight / p_min, -self._beta)
         return index, {"_weight": weight}
 
+    @torch.no_grad()
     def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
         priority = self.default_priority
 
@@ -405,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
                 "priority should be a scalar or an iterable of the same "
                 "length as index"
             )
+        # make sure everything is cast to cpu
+        if isinstance(index, torch.Tensor) and not index.is_cpu:
+            index = index.cpu()
+        if isinstance(priority, torch.Tensor) and not priority.is_cpu:
+            priority = priority.cpu()
 
         self._sum_tree[index] = priority
         self._min_tree[index] = priority
@@ -435,10 +441,10 @@ def update_priority(
                 indexed elements.
 
         """
-        priority = torch.as_tensor(
-            priority, dtype=torch.long, device=torch.device("cpu")
-        )
-        index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
+        priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
+        index = torch.as_tensor(
+            index, dtype=torch.long, device=torch.device("cpu")
+        ).detach()
         # we need to reshape priority if it has more than one elements or if it has
         # a different shape than index
         if priority.numel() > 1 and priority.shape != index.shape:
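
Note (not part of the patch): a minimal usage sketch of the code path this diff touches, showing priorities computed on an accelerator and still attached to the autograd graph being handed to the sampler through the replay buffer. The buffer sizes and hyperparameters below are illustrative assumptions, not taken from the patch.

    import torch
    from torchrl.data import LazyTensorStorage, PrioritizedSampler, ReplayBuffer

    # Illustrative sizes/hyperparameters (assumptions, not from the patch).
    rb = ReplayBuffer(
        storage=LazyTensorStorage(1000),
        sampler=PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.9),
    )

    # Store a batch of dummy transitions; `extend` returns the storage indices.
    data = torch.randn(64, 4)
    index = rb.extend(data)

    # td-errors produced on an accelerator, possibly still requiring grad.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    td_error = torch.rand(64, device=device, requires_grad=True)

    # With this patch, `update_priority` casts index and priority to CPU and
    # detaches them, so the sum/min trees never see CUDA tensors or autograd
    # history.
    rb.update_priority(index, td_error)

    # Prioritized sampling returns importance-sampling weights in the info dict.
    sample, info = rb.sample(32, return_info=True)  # info["_weight"]

Note also that the diff drops the `dtype=torch.long` cast on `priority` in `update_priority`, so fractional priorities are no longer truncated before being written to the trees.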