Commit a05eaae: amend
vmoens committed Feb 6, 2024 (1 parent: d567665)

Showing 2 changed files with 50 additions and 12 deletions.
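This commit threads tensordict dimension names through the replay-buffer storages: the storage classes gain a `names` argument, and `extend` propagates the leading dimension name of incoming data. A minimal usage sketch of what the change enables, assuming the `names` keyword added in the diff below plus the public TorchRL/tensordict APIs:

# Sketch only: the `names` kwarg is introduced by this commit.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

# Name the leading (time) dimension of the lazily allocated storage.
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000, names=["time"]))
data = TensorDict({"obs": torch.randn(10, 4)}, batch_size=[10])
rb.extend(data)  # storage is initialized on first write, with dim name "time"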
torchrl/data/replay_buffers/replay_buffers.py (6 changes: 4 additions & 2 deletions)
@@ -907,11 +907,13 @@ def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:

         if self._transform is not None:
             data = self._transform.inv(tensordicts.get("_data"))
-            tensordicts.set("_data", data)
+            tensordicts._set_str("_data", data, validated=True, inplace=False)
             if data.device is not None:
                 tensordicts = tensordicts.to(data.device)

-        tensordicts.batch_size = tensordicts.get("_data").batch_size[:1]
+        _data = tensordicts.get("_data")
+        tensordicts.batch_size = _data.batch_size[:1]
+        tensordicts.names = _data.names[:1]
         tensordicts.set(
             "index",
             torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
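The hunk above does two things: it writes `_data` back without re-validation (`_set_str(..., validated=True, inplace=False)`), and it copies the leading dimension name of the wrapped `_data` onto the outer tensordict along with its batch size. The user-visible effect, sketched under the same assumptions as above:

# Sketch: dimension names set on the input now survive extend().
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100))
data = TensorDict({"obs": torch.randn(8, 3)}, batch_size=[8])
data.names = ["time"]  # name the leading dimension
rb.extend(data)  # the storage adopts "time" from the first batch it sees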
torchrl/data/replay_buffers/storages.py (56 changes: 46 additions & 10 deletions)
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations

 import abc
 import json
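The added `from __future__ import annotations` is what lets the signatures below use the PEP 604 union syntax `List[str] | None` on Python versions older than 3.10: with postponed evaluation, annotations are stored as strings rather than evaluated at definition time. A small stand-alone illustration (toy function, not part of the diff):

from __future__ import annotations  # annotations stored as strings, not evaluated

from typing import List

def f(names: List[str] | None = None):  # OK even on Python 3.8/3.9
    return names

# Without the future import, `List[str] | None` is evaluated when `def` runs
# and raises TypeError on interpreters without PEP 604 (Python < 3.10).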
@@ -258,6 +259,11 @@ class TensorStorage(Storage):
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.
names (list of str): the names of the dimensions of the storage if it is
a TensorDict or equivalent. This option has no effect for non-tensordict
storages. Defaults to ``None`` (no name, or name of the first dimension
of the first batch provided if initialized through
:meth:`~torchrl.data.ReplayBuffer.extend`).
Examples:
>>> data = TensorDict({
@@ -307,12 +313,15 @@ class TensorStorage(Storage):
"""

@classmethod
def __new__(cls, *args, **kwargs):
cls._storage = None
return super().__new__(cls)
_storage = None

def __init__(self, storage, max_size=None, device="cpu"):
def __init__(
self,
storage,
max_size=None,
device: torch.device = "cpu",
names: List[str] | None = None,
):
if not ((storage is None) ^ (max_size is None)):
if storage is None:
raise ValueError("Expected storage to be non-null.")
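Dropping the `__new__` override in favor of a class-level `_storage = None` removes a subtle side effect: the old hook reassigned the class attribute on every instantiation, while a plain class attribute is an inert default that each instance shadows on first assignment. A toy illustration of the pattern (not TorchRL code):

class Old:
    def __new__(cls, *args, **kwargs):
        cls._storage = None  # mutates CLASS state on every instantiation
        return super().__new__(cls)

class New:
    _storage = None  # shared default until an instance shadows it

a = New()
a._storage = "buffer"        # assignment creates an instance attribute
b = New()
assert b._storage is None    # default untouched
assert New._storage is None  # class state never mutated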
@@ -327,6 +336,7 @@ def __init__(self, storage, max_size=None, device="cpu"):
         else:
             max_size = tree_flatten(storage)[0][0].shape[0]
         super().__init__(max_size)
+        self.names = names
         self.initialized = storage is not None
         if self.initialized:
             self._len = max_size
@@ -543,6 +553,8 @@ def set(
             self._len = max(self._len, max(cursor) + 1)

         if not self.initialized:
+            if self.names is None and is_tensor_collection(data):
+                self.names = data.names[:1]
             if not isinstance(cursor, INT_CLASSES):
                 if is_tensor_collection(data):
                     self._init(data[0])
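Note the `[:1]` here and in `extend` above: a storage is only ever indexed along its first (write) dimension, so only the leading name is retained from possibly multi-dimensional batches. Plain tensordict semantics, nothing commit-specific:

import torch
from tensordict import TensorDict

data = TensorDict({"obs": torch.randn(8, 2, 3)}, batch_size=[8, 2])
data.names = ["time", "agents"]
print(data.names[:1])  # ['time'] -- the only name the flat storage keeps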
@@ -637,6 +649,11 @@ class LazyTensorStorage(TensorStorage):
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.
names (list of str): the names of the dimensions of the storage if it is
a TensorDict or equivalent. This option has no effect for non-tensordict
storages. Defaults to ``None`` (no name, or name of the first dimension
of the first batch provided if initialized through
:meth:`~torchrl.data.ReplayBuffer.extend`).
Examples:
>>> data = TensorDict({
@@ -688,8 +705,13 @@ class LazyTensorStorage(TensorStorage):
"""

def __init__(self, max_size, device="cpu"):
super().__init__(storage=None, max_size=max_size, device=device)
def __init__(
self,
max_size: int,
device: torch.device = "cpu",
names: List[str] | None = None,
):
super().__init__(storage=None, max_size=max_size, device=device, names=names)

def _init(
self,
@@ -711,6 +733,8 @@ def _init(
                 .clone()
                 .to(self.device)
             )
+            if self.names is not None:
+                out.names = self.names
         else:
             # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
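With the two added lines, a `LazyTensorStorage` constructed with `names` stamps them onto the preallocated buffer during lazy initialization. A hedged sketch of driving the storage directly (the `set` indices are illustrative):

# Sketch: lazy init happens on the first write.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage

storage = LazyTensorStorage(100, names=["time"])
batch = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
storage.set(range(4), batch)  # _init runs here and applies the names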
@@ -739,6 +763,11 @@ class LazyMemmapStorage(LazyTensorStorage):
             If ``None`` is provided, the device is automatically gathered from the
             first batch of data passed. This is not enabled by default to avoid
             data placed on GPU by mistake, causing OOM issues.
+        names (list of str): the names of the dimensions of the storage if it is
+            a TensorDict or equivalent. This option has no effect for non-tensordict
+            storages. Defaults to ``None`` (no name, or name of the first dimension
+            of the first batch provided if initialized through
+            :meth:`~torchrl.data.ReplayBuffer.extend`).

     Examples:
         >>> data = TensorDict({
@@ -789,8 +818,14 @@ class LazyMemmapStorage(LazyTensorStorage):
"""

def __init__(self, max_size, scratch_dir=None, device="cpu"):
super().__init__(max_size)
def __init__(
self,
max_size: int,
scratch_dir=None,
device: torch.device = "cpu",
names: List[str] | None = None,
):
super().__init__(max_size, names=names)
self.initialized = False
self.scratch_dir = None
if scratch_dir is not None:
@@ -874,7 +909,8 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
             out = data.clone().to(self.device)
             out = out.expand(self.max_size, *data.shape)
             out = out.memmap_like(prefix=self.scratch_dir)
-
+            if self.names is not None:
+                out.names = self.names
             for key, tensor in sorted(
                 out.items(include_nested=True, leaves_only=True), key=str
             ):
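The same two lines close out `LazyMemmapStorage._init`, applying the stored names right after `memmap_like` builds the memory-mapped buffer. A final sketch combining the memmap storage with the new keyword (the scratch path and sizes are placeholders):

# Sketch only: assumes this commit.
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(10_000, scratch_dir="/tmp/rb", names=["time"])
)
rb.extend(TensorDict({"obs": torch.randn(16, 4)}, batch_size=[16]))
sample = rb.sample(8)  # data lives on disk under /tmp/rb, leading dim named "time"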
