diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index ebeea7dcf6b..8c928ced9b6 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -883,40 +883,21 @@ def add(self, data: TensorDictBase) -> int: return index def extend(self, tensordicts: TensorDictBase) -> torch.Tensor: + _data = tensordicts + + if self._transform is not None: + _data = self._transform.inv(_data) tensordicts = TensorDict( - {"_data": tensordicts}, - batch_size=torch.Size([]), + {"_data": _data}, + batch_size=_data.shape[:1], + device=_data.device, + names=_data.names[:1] if _data._has_names() else None, ) - if tensordicts.batch_dims > 1: - # we want the tensordict to have one dimension only. The batch size - # of the sampled tensordicts can be changed thereafter - if not isinstance(tensordicts, LazyStackedTensorDict): - tensordicts = tensordicts.clone(recurse=False) - else: - tensordicts = tensordicts.contiguous() - # we keep track of the batch size to reinstantiate it when sampling - if "_rb_batch_size" in tensordicts.keys(): - raise KeyError( - "conflicting key '_rb_batch_size'. Consider removing from data." - ) - shape = torch.tensor(tensordicts.batch_size[1:]).expand( - tensordicts.batch_size[0], tensordicts.batch_dims - 1 - ) - tensordicts.set("_rb_batch_size", shape) - if self._transform is not None: - data = self._transform.inv(tensordicts.get("_data")) - tensordicts._set_str("_data", data, validated=True, inplace=False) - if data.device is not None: - tensordicts = tensordicts.to(data.device) - - _data = tensordicts.get("_data") - tensordicts.batch_size = _data.batch_size[:1] - tensordicts.names = _data.names[:1] tensordicts.set( "index", - torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int), + torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.long), ) index = super()._extend(tensordicts) @@ -989,10 +970,12 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]: with self._replay_lock: index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index - data = self._storage[index] + data = self._storage.get(index) if not isinstance(index, INT_CLASSES): data = self._collate_fn(data) - + if self._transform is not None and len(self._transform): + with data.unlock_(), _set_dispatch_td_nn_modules(True): + data = self._transform(data) return data, info diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 1a0119cd50e..f975bf34b86 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -734,7 +734,8 @@ def _init( .to(self.device) ) if self.names is not None: - out.names = self.names + names = self.names + [None] * (out.batch_dims - len(self.names)) + out.refine_names(*names) else: # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype out = tree_map(