diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index c9751faf01e..c2b4e69389c 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -617,9 +617,9 @@ def _add(self, data): return index def _extend(self, data: Sequence) -> torch.Tensor: - is_compiling = is_compiling() + is_comp = is_compiling() nc = contextlib.nullcontext() - with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc: + with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data)