Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 8, 2024
1 parent 1874e9a commit ffcdf51
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
62 changes: 38 additions & 24 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
Returns:
Indices of the data added to the replay buffer.
"""
if self._transform is not None and (
is_tensor_collection(data) or len(self._transform)
):
if self._transform is not None and (is_tensor_collection(data)):
data = self._transform.inv(data)
return self._extend(data)

Expand All @@ -380,12 +378,8 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]:
if not is_tensor_collection(data):
data = TensorDict({"data": data}, [])
is_td = False
is_locked = data.is_locked
if is_locked:
data.unlock_()
data = self._transform(data)
if is_locked:
data.lock_()
with data.unlock_():
data = self._transform(data)
if not is_td:
data = data["data"]

Expand Down Expand Up @@ -800,7 +794,7 @@ def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:

tensordicts = TensorDict(
{"_data": tensordicts},
batch_size=tensordicts.batch_size[:1],
batch_size=torch.Size([]),
)
if tensordicts.batch_dims > 1:
# we want the tensordict to have one dimension only. The batch size
Expand All @@ -818,17 +812,19 @@ def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
tensordicts.batch_size[0], tensordicts.batch_dims - 1
)
tensordicts.set("_rb_batch_size", shape)
tensordicts.set(
"index",
torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
)

if self._transform is not None:
data = self._transform.inv(tensordicts.get("_data"))
tensordicts.set("_data", data)
if data.device is not None:
tensordicts = tensordicts.to(data.device)

tensordicts.batch_size = tensordicts.get("_data").batch_size[:1]
tensordicts.set(
"index",
torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
)

index = super()._extend(tensordicts)
self.update_tensordict_priority(tensordicts)
return index
Expand Down Expand Up @@ -877,20 +873,38 @@ def sample(

data, info = super().sample(batch_size, return_info=True)
if not is_tensorclass(data) and include_info in (True, None):
is_locked = data.is_locked
if is_locked:
data.unlock_()
for k, v in info.items():
v = _to_torch(v, data.device)
if v.shape[: data.batch_dims] != data.batch_size:
v = expand_as_right(v, data)
data.set(k, v)
if is_locked:
data.lock_()
with data.unlock_():
for k, v in info.items():
v = _to_torch(v, data.device)
if v.shape[: data.batch_dims] != data.batch_size:
v = expand_as_right(v, data)
data.set(k, v)

if self._transform is not None and len(self._transform):
is_td = True
if not is_tensor_collection(data):
data = TensorDict({"data": data}, [])
is_td = False
with data.unlock_():
data = self._transform(data)
if not is_td:
data = data["data"]

if return_info:
return data, info
return data

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
    """Draw ``batch_size`` items from storage via the sampler.

    Holds ``self._replay_lock`` while sampling indices and reading the
    storage so concurrent writers cannot interleave. The sampled indices
    are recorded in ``info["index"]``. When the sampler returns a batch
    of indices (anything that is not a plain integer), the gathered items
    are passed through ``self._collate_fn`` to assemble the batch.

    Returns:
        A ``(data, info)`` pair: the (possibly collated) sampled data and
        the sampler's info dict augmented with the ``"index"`` entry.
    """
    with self._replay_lock:
        sampled_index, sample_info = self._sampler.sample(self._storage, batch_size)
        sample_info["index"] = sampled_index
        batch = self._storage[sampled_index]
        # A single integer index means one item was fetched directly;
        # only multi-index results need collation into a batch.
        if not isinstance(sampled_index, INT_CLASSES):
            batch = self._collate_fn(batch)
    return batch, sample_info


class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
"""TensorDict-specific wrapper around the :class:`~torchrl.data.PrioritizedReplayBuffer` class.
Expand Down
3 changes: 1 addition & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:

@dispatch(source="in_keys_inv", dest="out_keys_inv")
def inv(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Apply the inverse transforms to *tensordict* and return the result.

    The input is cloned non-recursively (``clone(False)``) before being
    handed to ``_inv_call``, so the transform chain works on a fresh
    tensordict structure rather than mutating the caller's container
    in place (leaf tensors are presumably still shared — the clone is
    shallow; confirm against tensordict's ``clone`` semantics).
    """
    # The diff residue left an unreachable second ``return`` after
    # ``return out``; collapse to the single-expression form.
    return self._inv_call(tensordict.clone(False))

def transform_env_device(self, device: torch.device):
"""Transforms the device of the parent env."""
Expand Down

0 comments on commit ffcdf51

Please sign in to comment.