diff --git a/test/test_rb.py b/test/test_rb.py
index 359b245fd9f..af7140b3984 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -546,6 +546,20 @@ def test_errors(self, storage_type):
         ):
             storage_type(data, max_size=4)
 
+    def test_existsok_lazymemmap(self, tmpdir):
+        storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storage0)
+        rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+        storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storage1)
+        with pytest.raises(RuntimeError, match="existsok"):
+            rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+        storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True)
+        rb = ReplayBuffer(storage=storage2)
+        rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
     @pytest.mark.parametrize(
         "data_type", ["tensor", "tensordict", "tensorclass", "pytree"]
     )
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index e49ab509a01..20b2169cc8e 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -923,6 +923,8 @@ class LazyMemmapStorage(LazyTensorStorage):
     Args:
         max_size (int): size of the storage, i.e. maximum number of elements stored
             in the buffer.
+
+    Keyword Args:
         scratch_dir (str or path): directory where memmap-tensors will be written.
         device (torch.device, optional): device where the sampled tensors will be
             stored and sent. Default is :obj:`torch.device("cpu")`.
@@ -933,6 +935,9 @@ class LazyMemmapStorage(LazyTensorStorage):
             measuring the storage size. For instance, a storage of shape ``[3, 4]``
             has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
             Defaults to ``1``.
+        existsok (bool, optional): whether it is acceptable for a tensor to already
+            exist on disk. Defaults to ``False``, in which case an error is raised.
+            If ``True``, the existing tensor is opened as-is, not overwritten.
 
     .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to
         where the storage is already stored to avoid executing long copies of data that is already stored on disk.
@@ -1009,10 +1014,12 @@ def __init__(
         scratch_dir=None,
         device: torch.device = "cpu",
         ndim: int = 1,
+        existsok: bool = False,
     ):
         super().__init__(max_size, ndim=ndim)
         self.initialized = False
         self.scratch_dir = None
+        self.existsok = existsok
         if scratch_dir is not None:
             self.scratch_dir = str(scratch_dir)
             if self.scratch_dir[-1] != "/":
@@ -1108,7 +1115,7 @@ def max_size_along_dim0(data_shape):
         if is_tensor_collection(data):
             out = data.clone().to(self.device)
             out = out.expand(max_size_along_dim0(data.shape))
-            out = out.memmap_like(prefix=self.scratch_dir)
+            out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             for key, tensor in sorted(
                 out.items(include_nested=True, leaves_only=True), key=str
             ):
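
For reference, a minimal sketch of the behavior this patch introduces, mirroring the new test but outside the test suite. It assumes the public `torchrl.data` and `tensordict` imports shown below; the temporary directory is illustrative, any writable path works.

import tempfile

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer

scratch_dir = tempfile.mkdtemp()  # illustrative scratch location

# First buffer: writes the memmap files under scratch_dir on first extend().
rb = ReplayBuffer(storage=LazyMemmapStorage(10, scratch_dir=scratch_dir))
rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))

# Re-using the same scratch_dir with the default existsok=False raises a
# RuntimeError (mentioning "existsok") on the first write. With existsok=True,
# the pre-existing memmap files are opened as-is instead of raising:
rb2 = ReplayBuffer(
    storage=LazyMemmapStorage(10, scratch_dir=scratch_dir, existsok=True)
)
rb2.extend(TensorDict(a=torch.randn(3), batch_size=[3]))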