diff --git a/test/test_rb.py b/test/test_rb.py index 7b763ec7dbe..1352e3ed629 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -775,6 +775,12 @@ class TC: if data == "tc": assert storage._storage.text == storage_recover._storage.text + def test_add_list_of_tds(self): + rb = ReplayBuffer(storage=LazyTensorStorage(100)) + rb.extend([TensorDict({"a": torch.randn(2, 3)}, [2])]) + assert len(rb) == 1 + assert rb[:].shape == torch.Size([1, 2]) + @pytest.mark.parametrize("max_size", [1000]) @pytest.mark.parametrize("shape", [[3, 4]]) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 4a91d6f2317..214777cb772 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1511,6 +1511,8 @@ def save_tensor(tensor_path: str, tensor: torch.Tensor): def _flip_list(data): + if all(is_tensor_collection(_data) for _data in data): + return torch.stack(data) flat_data, flat_specs = zip(*[tree_flatten(item) for item in data]) flat_data = zip(*flat_data) stacks = [torch.stack(item) for item in flat_data]