diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 852592992b9..607be49211a 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -148,6 +148,7 @@ using the following components:
     LazyMemmapStorage
     LazyTensorStorage
     ListStorage
+    LazyStackStorage
     ListStorageCheckpointer
     NestedStorageCheckpointer
     PrioritizedSampler
diff --git a/test/test_rb.py b/test/test_rb.py
index a139d34f1a5..b63f888453d 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -81,6 +81,7 @@
 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     ListStorage,
     StorageEnsemble,
@@ -1116,6 +1117,31 @@ def test_storage_inplace_writing_ndim(self, storage_type):
         assert (rb[:, 10:20] == 0).all()
         assert len(rb) == 100

+    @pytest.mark.parametrize("max_size", [1000, None])
+    @pytest.mark.parametrize("stack_dim", [-1, 0])
+    def test_lazy_stack_storage(self, max_size, stack_dim):
+        # Create an instance of LazyStackStorage with the given parameters
+        storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim)
+        # Create a ReplayBuffer using the created storage
+        rb = ReplayBuffer(storage=storage)
+        # Generate some random data to add to the buffer
+        torch.manual_seed(0)
+        data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
+        data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
+        # Add the data to the buffer
+        rb.add(data0)
+        rb.add(data1)
+        # Sample from the buffer
+        sample = rb.sample(10)
+        # Check that the sampled data has the correct shape and type
+        assert isinstance(sample, LazyStackedTensorDict)
+        assert sample["b"].shape[0] == 10
+        assert all(isinstance(item, str) for item in sample["c"])
+        # Densify the sample and check that the heterogeneous "a" entries are packed into a single tensor
+        sample = sample.densify(layout=torch.jagged)
+        assert isinstance(sample["a"], torch.Tensor)
+        assert sample["a"].shape[0] == 10
+
     @pytest.mark.parametrize("max_size", [1000])
     @pytest.mark.parametrize("shape", [[3, 4]])
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index 3ed65d59d16..7fa882cbbaa 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -23,6 +23,7 @@
     H5StorageCheckpointer,
     ImmutableDatasetWriter,
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     ListStorage,
     ListStorageCheckpointer,
diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py
index 25822dcfe4c..4f230f30701 100644
--- a/torchrl/data/replay_buffers/__init__.py
+++ b/torchrl/data/replay_buffers/__init__.py
@@ -32,6 +32,7 @@
 )
 from .storages import (
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     ListStorage,
     Storage,
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index fc27401d5e5..fa92d84295a 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1049,6 +1049,10 @@ def __init__(
             self._cache["stop-and-length"] = vals

         else:
+            if traj_key is not None:
+                self._fetch_traj = True
+            elif end_key is not None:
+                self._fetch_traj = False
             if end_key is None:
                 end_key = ("next", "done")
             if traj_key is None:
@@ -1331,7 +1335,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         if start_idx.shape[1] != storage.ndim:
             raise RuntimeError(
                 f"Expected the end-of-trajectory signal to be "
-                f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
+                f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
                 "instead."
             )
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 52d137208ad..344814e728c 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -297,7 +297,15 @@ def set(
     def get(self, index: Union[int, Sequence[int], slice]) -> Any:
         if isinstance(index, (INT_CLASSES, slice)):
             return self._storage[index]
+        elif isinstance(index, tuple):
+            if len(index) > 1:
+                raise RuntimeError(
+                    f"{type(self).__name__} can only be indexed with one-length tuples."
+                )
+            return self.get(index[0])
         else:
+            if isinstance(index, torch.Tensor) and index.device.type != "cpu":
+                index = index.cpu().tolist()
             return [self._storage[i] for i in index]

     def __len__(self):
@@ -353,6 +361,77 @@ def contains(self, item):
         raise NotImplementedError(f"type {type(item)} is not supported yet.")

+
+class LazyStackStorage(ListStorage):
+    """A ListStorage that returns LazyStackedTensorDict instances.
+
+    This storage allows heterogeneous structures to be indexed as a single ``TensorDict`` representation.
+    It uses :class:`~tensordict.LazyStackedTensorDict`, which operates on non-contiguous lists of tensordicts
+    and stacks the items lazily when they are queried.
+    This means that sampling is fast, but accessing the sampled data may be slow (as it requires stacking the items).
+    Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
+    Because the storage is represented as a list, the number of tensors kept in memory grows linearly with
+    the size of the buffer.
+
+    If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
+    (see :mod:`~torch.nested`).
+
+    Args:
+        max_size (int, optional): the maximum number of elements stored in the storage.
+            If not provided, an unlimited storage is created.
+
+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile`
+            at the cost of not being executable in multiprocessed settings.
+        stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to ``-1``.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage
+        >>> from tensordict import TensorDict
+        >>> _ = torch.manual_seed(0)
+        >>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
+        >>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
+        >>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
+        >>> _ = rb.add(data0)
+        >>> _ = rb.add(data1)
+        >>> rb.sample(10)
+        LazyStackedTensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                c: NonTensorStack(
+                    ['another string!', 'another string!', 'another st...,
+                    batch_size=torch.Size([10]),
+                    device=None)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([10]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+    """
+
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        stack_dim: int = -1,
+    ):
+        super().__init__(max_size=max_size, compilable=compilable)
+        self.stack_dim = stack_dim
+
+    def get(self, index: Union[int, Sequence[int], slice]) -> Any:
+        out = super().get(index=index)
+        if isinstance(out, list):
+            stack_dim = self.stack_dim
+            if stack_dim < 0:
+                stack_dim = out[0].ndim + 1 + stack_dim
+            out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
+            return out
+        return out
+
+
 class TensorStorage(Storage):
     """A storage for tensors and tensordicts.
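
Usage sketch (not part of the patch, mirroring the new test): the docstring mentions that nested tensors can be obtained from a sampled lazy stack via ``LazyStackedTensorDict.densify``, and the test exercises this with ``layout=torch.jagged``. Assuming the same API as in the test, the path looks roughly like this:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyStackStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=100))
>>> _ = rb.add(TensorDict(a=torch.randn(10), b=torch.rand(4)))  # "a" has 10 elements here ...
>>> _ = rb.add(TensorDict(a=torch.randn(11), b=torch.rand(4)))  # ... and 11 elements here
>>> sample = rb.sample(8)  # a LazyStackedTensorDict of 8 items; nothing is stacked yet
>>> dense = sample.densify(layout=torch.jagged)  # pack the ragged "a" entries
>>> assert isinstance(dense["a"], torch.Tensor) and dense["a"].shape[0] == 8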
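
Side note on the ``SliceSampler`` change (a sketch under stated assumptions, not part of the patch): the added branch in ``__init__`` appears to record, via ``_fetch_traj``, whether the user passed ``traj_key`` or ``end_key`` explicitly, before the defaults are filled in, so an explicit ``end_key`` drives done-flag-based slicing directly rather than relying on the default ``"episode"`` lookup. The keys and toy data below are illustrative only:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, SliceSampler, TensorDictReplayBuffer
>>> done = torch.zeros(20, 1, dtype=torch.bool)
>>> done[4::5] = True  # an episode boundary every 5 steps
>>> data = TensorDict(obs=torch.randn(20, 3), next=TensorDict(done=done, batch_size=[20]), batch_size=[20])
>>> sampler = SliceSampler(num_slices=2, end_key=("next", "done"))  # explicit end_key -> done-flag slicing
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(20), sampler=sampler, batch_size=10)
>>> _ = rb.extend(data)
>>> sample = rb.sample()  # two 5-step slices, each drawn from a single episode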