diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index f975bf34b86..8b182e33fd2 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -911,7 +911,8 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: out = out.expand(self.max_size, *data.shape) out = out.memmap_like(prefix=self.scratch_dir) if self.names is not None: - out.names = self.names + names = self.names + [None] * (out.batch_dims - len(self.names)) + out.refine_names(*names) for key, tensor in sorted( out.items(include_nested=True, leaves_only=True), key=str ):