From d509f865f49323c9d087dbca7bd1dc4299ff5857 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Sep 2024 16:05:45 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/replay_buffers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 2e0eeb80705..9040cc8b0cf 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -652,7 +652,6 @@ def update_priority( with self._replay_lock, self._write_lock: self._sampler.update_priority(index, priority, storage=self.storage) - @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: with self._replay_lock: index, info = self._sampler.sample(self._storage, batch_size) @@ -675,6 +674,7 @@ def empty(self): self._sampler._empty() self._storage._empty() + @pin_memory_output def sample(self, batch_size: int | None = None, return_info: bool = False) -> Any: """Samples a batch of data from the replay buffer. @@ -1262,6 +1262,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: index = index[..., 0] return self.update_priority(index, priority) + @pin_memory_output def sample( self, batch_size: int | None = None,