amend
vmoens committed Feb 7, 2024
1 parent 34fadac · commit fe960fa
Showing 2 changed files with 15 additions and 31 deletions.
torchrl/data/replay_buffers/replay_buffers.py (13 additions, 30 deletions)
```diff
@@ -883,40 +883,21 @@ def add(self, data: TensorDictBase) -> int:
         return index
 
     def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+        _data = tensordicts
+
+        if self._transform is not None:
+            _data = self._transform.inv(_data)
+
         tensordicts = TensorDict(
-            {"_data": tensordicts},
-            batch_size=torch.Size([]),
+            {"_data": _data},
+            batch_size=_data.shape[:1],
+            device=_data.device,
+            names=_data.names[:1] if _data._has_names() else None,
         )
-        if tensordicts.batch_dims > 1:
-            # we want the tensordict to have one dimension only. The batch size
-            # of the sampled tensordicts can be changed thereafter
-            if not isinstance(tensordicts, LazyStackedTensorDict):
-                tensordicts = tensordicts.clone(recurse=False)
-            else:
-                tensordicts = tensordicts.contiguous()
-            # we keep track of the batch size to reinstantiate it when sampling
-            if "_rb_batch_size" in tensordicts.keys():
-                raise KeyError(
-                    "conflicting key '_rb_batch_size'. Consider removing from data."
-                )
-            shape = torch.tensor(tensordicts.batch_size[1:]).expand(
-                tensordicts.batch_size[0], tensordicts.batch_dims - 1
-            )
-            tensordicts.set("_rb_batch_size", shape)
-
-        if self._transform is not None:
-            data = self._transform.inv(tensordicts.get("_data"))
-            tensordicts._set_str("_data", data, validated=True, inplace=False)
-            if data.device is not None:
-                tensordicts = tensordicts.to(data.device)
-
-        _data = tensordicts.get("_data")
-        tensordicts.batch_size = _data.batch_size[:1]
-        tensordicts.names = _data.names[:1]
         tensordicts.set(
             "index",
-            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
+            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.long),
         )
 
         index = super()._extend(tensordicts)
```
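Note: the rewritten `extend` applies the inverse transform to the raw input first, then wraps it once with the correct batch size, device, and leading dimension name, rather than wrapping with an empty batch size and patching it up afterwards. Below is a minimal standalone sketch of that wrapping, using only public `torch` and `tensordict` APIs; the keys, shapes, and sample data are illustrative assumptions, not torchrl code.

```python
import torch
from tensordict import TensorDict

# Hypothetical incoming batch of 4 transitions.
data = TensorDict(
    {"obs": torch.randn(4, 3), "reward": torch.randn(4, 1)},
    batch_size=[4],
)

# Wrap as the new extend() does: nest under "_data", keep only the
# leading batch dimension, and attach a long-dtype index placeholder.
wrapped = TensorDict({"_data": data}, batch_size=data.shape[:1], device=data.device)
wrapped.set(
    "index",
    torch.zeros(wrapped.shape, device=wrapped.device, dtype=torch.long),
)
print(wrapped)  # batch_size=torch.Size([4]) with keys "_data" and "index"
```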
```diff
@@ -989,10 +970,12 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]:
         with self._replay_lock:
             index, info = self._sampler.sample(self._storage, batch_size)
             info["index"] = index
-            data = self._storage[index]
+            data = self._storage.get(index)
         if not isinstance(index, INT_CLASSES):
             data = self._collate_fn(data)
 
         if self._transform is not None and len(self._transform):
             with data.unlock_(), _set_dispatch_td_nn_modules(True):
                 data = self._transform(data)
         return data, info
```
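Two details in `_sample` are easy to miss: the storage is now read through its `get` method instead of `__getitem__`, and the `index` entry written by `extend` (first hunk above) is allocated as `torch.long`, the canonical dtype PyTorch expects for tensor indices. A standalone sketch of the dtype point (plain PyTorch, not torchrl code):

```python
import torch

storage = torch.arange(10.0)

# int64 (long) is the canonical index dtype in PyTorch: both advanced
# indexing and index_select accept it.
idx = torch.tensor([0, 3, 7], dtype=torch.long)
print(storage[idx])                  # tensor([0., 3., 7.])
print(storage.index_select(0, idx))  # same selection
```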
torchrl/data/replay_buffers/storages.py (2 additions, 1 deletion)
```diff
@@ -734,7 +734,8 @@ def _init(
             .to(self.device)
         )
         if self.names is not None:
-            out.names = self.names
+            names = self.names + [None] * (out.batch_dims - len(self.names))
+            out.refine_names(*names)
         else:
             # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
```
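The storages fix pads the configured dimension names with `None` before refining, so a names list that covers only the leading dimensions no longer fails when the stored tensordict has extra batch dimensions. The same padding trick on a plain named tensor (the shape and the `"time"` name are assumptions for illustration):

```python
import torch

x = torch.zeros(5, 3, 4)  # more dimensions than we have names for
names = ["time"]          # only the leading dimension is named

# Pad with None so there is exactly one entry per dimension,
# mirroring what the patch does before calling refine_names.
names = names + [None] * (x.dim() - len(names))
x = x.refine_names(*names)
print(x.names)  # ('time', None, None)
```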
