From be46f1655e83f686838f23b9233775f36f5b0de2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 7 Feb 2024 00:33:22 -0800 Subject: [PATCH] amend --- torchrl/data/replay_buffers/samplers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index e1aa8e7c6db..9cd851ff52f 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -435,8 +435,10 @@ def update_priority( indexed elements. """ - priority = torch.as_tensor(priority, dtype=torch.long) - index = torch.as_tensor(index, dtype=torch.long) + priority = torch.as_tensor( + priority, dtype=torch.long, device=torch.device("cpu") + ) + index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu")) # we need to reshape priority if it has more than one elements or if it has # a different shape than index if priority.numel() > 1 and priority.shape != index.shape: