diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index ff2a00ab242..e7f4da9c4bb 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -176,7 +176,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(_cursor, data)
         index = self._replicate_index(index)
-        for ent in self._storage._attached_entities_iter:
+        for ent in self._storage._attached_entities_iter():
             ent.mark_update(index)
         return index

@@ -302,7 +302,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         )
         self._storage.set(index, data)
         index = self._replicate_index(index)
-        for ent in self._storage._attached_entities_iter:
+        for ent in self._storage._attached_entities_iter():
             ent.mark_update(index)
         return index

@@ -332,7 +332,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(index, data)
         index = self._replicate_index(index)
-        for ent in self._storage._attached_entities_iter:
+        for ent in self._storage._attached_entities_iter():
             ent.mark_update(index)
         return index

@@ -533,7 +533,7 @@ def add(self, data: Any) -> int | torch.Tensor:
         # Other than that, a "flat" (1d) index is ok to write the data
         self._storage.set(index, data)
         index = self._replicate_index(index)
-        for ent in self._storage._attached_entities_iter:
+        for ent in self._storage._attached_entities_iter():
             ent.mark_update(index)
         return index

@@ -567,7 +567,7 @@ def extend(self, data: TensorDictBase) -> None:
         device = getattr(self._storage, "device", None)
         out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
         index = self._replicate_index(out_index)
-        for ent in self._storage._attached_entities_iter:
+        for ent in self._storage._attached_entities_iter():
             ent.mark_update(index)
         return index