Commit: amend
vmoens committed Feb 15, 2024
1 parent 775ffb6 commit 01bcc79
Showing 4 changed files with 46 additions and 53 deletions.
33 changes: 5 additions & 28 deletions test/test_rb.py
@@ -6,6 +6,7 @@
 import argparse
 import contextlib
 import functools
+import gc
 import importlib
 import os
 import pickle
@@ -157,44 +158,20 @@ def _get_datum(self, datatype):

     def _get_data(self, datatype, size):
         if datatype is None:
-            data = torch.randint(
-                100,
-                (
-                    size,
-                    1,
-                ),
-            )
+            data = torch.randint(100, (size, 1))
         elif datatype == "tensor":
-            data = torch.randint(
-                100,
-                (
-                    size,
-                    1,
-                ),
-            )
+            data = torch.randint(100, (size, 1))
         elif datatype == "tensordict":
             data = TensorDict(
                 {
-                    "a": torch.randint(
-                        100,
-                        (
-                            size,
-                            1,
-                        ),
-                    ),
+                    "a": torch.randint(100, (size, 1)),
                     "next": {"reward": torch.randn(size, 1)},
                 },
                 [size],
             )
         elif datatype == "pytree":
             data = {
-                "a": torch.randint(
-                    100,
-                    (
-                        size,
-                        1,
-                    ),
-                ),
+                "a": torch.randint(100, (size, 1)),
                 "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]},
                 30: torch.zeros(size, 2),
             }
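For context, the pytree branch above (including the non-string key 30) round-trips through a plain ReplayBuffer. A minimal sketch, not part of this commit, assuming a TorchRL version where ReplayBuffer storages accept pytree data:

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    size = 8
    data = {
        "a": torch.randint(100, (size, 1)),
        "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]},
        30: torch.zeros(size, 2),
    }
    rb = ReplayBuffer(storage=LazyTensorStorage(100))
    rb.extend(data)
    sample = rb.sample(4)  # same pytree structure, batch dimension 4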
4 changes: 0 additions & 4 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -999,13 +999,9 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
         else:
             priority = torch.as_tensor(self._get_priority_item(data))
         index = data.get("index")
-        if data.ndim:
-            valid_index = index >= 0
         while index.shape != priority.shape:
             # reduce index
             index = index[..., 0]
-        if data.ndim:
-            return self.update_priority(index[valid_index], priority[valid_index])
         return self.update_priority(index, priority)

     def sample(
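The kept while loop aligns a possibly multi-dimensional index with a per-item priority by repeatedly dropping the trailing dimension; the valid-index filtering that was removed here now lives in the sampler. A standalone illustration of the loop with plain torch tensors:

    import torch

    priority = torch.rand(4)                 # one priority per batch element
    index = torch.arange(8).reshape(4, 2)    # e.g. (batch, time) indices
    while index.shape != priority.shape:
        # reduce index
        index = index[..., 0]
    # index is now tensor([0, 2, 4, 6]) and matches priority's shape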
37 changes: 30 additions & 7 deletions torchrl/data/replay_buffers/samplers.py
@@ -24,6 +24,7 @@

 from torchrl._utils import _replace_last
 from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
+from torchrl.data.replay_buffers.utils import _is_int

 try:
     from torchrl._torchrl import (
@@ -416,10 +417,18 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
                 "length as index"
             )
         # make sure everything is cast to cpu
-        if isinstance(index, torch.Tensor) and not index.is_cpu:
-            index = index.cpu()
-        if isinstance(priority, torch.Tensor) and not priority.is_cpu:
-            priority = priority.cpu()
+        index = torch.as_tensor(index, device=torch.device("cpu"), dtype=torch.long)
+        priority = torch.as_tensor(priority, device=torch.device("cpu"))
+        # MaxValueWriter will set -1 for items in the data that we don't want
+        # to update. We therefore have to keep only the non-negative indices.
+        valid_index = index >= 0
+        if not valid_index.all():
+            if valid_index.any():
+                index = index[valid_index]
+                if priority.numel() > 1:
+                    priority = priority[valid_index]
+            else:
+                return

         self._sum_tree[index] = priority
         self._min_tree[index] = priority
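The masking step added above can be reproduced in isolation: entries flagged with -1 by the writer are dropped before the sum/min trees are touched. A self-contained sketch of just that step (the real method returns early when no index is valid):

    import torch

    index = torch.tensor([3, -1, 7, -1])    # -1 marks items to skip
    priority = torch.tensor([0.5, 0.1, 0.9, 0.2])
    valid_index = index >= 0
    if not valid_index.all():
        if valid_index.any():
            index = index[valid_index]
            if priority.numel() > 1:
                priority = priority[valid_index]
    # index -> tensor([3, 7]); priority -> tensor([0.5000, 0.9000])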
@@ -451,9 +460,7 @@ def update_priority(
         """
         priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
-        index = torch.as_tensor(
-            index, dtype=torch.long, device=torch.device("cpu")
-        ).detach()
+        index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
         # we need to reshape priority if it has more than one elements or if it has
         # a different shape than index
         if priority.numel() > 1 and priority.shape != index.shape:
@@ -468,6 +475,22 @@
         elif priority.numel() <= 1:
             priority = priority.squeeze()

+        # MaxValueWriter will set -1 for items in the data that we don't want
+        # to update. We therefore have to keep only the non-negative indices.
+        if _is_int(index):
+            if index == -1:
+                return
+        else:
+            if index.ndim > 1:
+                raise ValueError(f"Unsupported index shape: {index.shape}.")
+            valid_index = index >= 0
+            if not valid_index.any():
+                return
+            if not valid_index.all():
+                index = index[valid_index]
+                if priority.numel():
+                    priority = priority[valid_index]
+
         self._max_priority = priority.max().clamp_min(self._max_priority).item()
         priority = torch.pow(priority + self._eps, self._alpha)
         self._sum_tree[index] = priority
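Taken together, update_priority now tolerates the -1 sentinel end-to-end. A usage sketch, not from the commit, assuming the public ReplayBuffer/PrioritizedSampler APIs (the constructor arguments shown are illustrative):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.data.replay_buffers.samplers import PrioritizedSampler

    rb = ReplayBuffer(
        storage=LazyTensorStorage(100),
        sampler=PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.9),
    )
    rb.extend(torch.arange(10).unsqueeze(-1))
    # A -1 index (as produced by TensorDictMaxValueWriter for rejected
    # items) is now skipped instead of corrupting the sum/min trees:
    rb.update_priority(torch.tensor([0, -1, 2]), torch.tensor([1.0, 1.0, 3.0]))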
25 changes: 11 additions & 14 deletions torchrl/data/replay_buffers/writers.py
@@ -343,29 +343,26 @@ def extend(self, data: TensorDictBase) -> None:
         The ``rank_key`` in the data passed to this module should be structured as [B].
         If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
         """
+        # a map of [idx_in_storage, idx_in_data]
         data_to_replace = {}
-        for i, sample in enumerate(data):
-            index = self.get_insert_index(sample)
-            if index is not None:
-                data_to_replace[index] = i
+        for data_idx, sample in enumerate(data):
+            storage_idx = self.get_insert_index(sample)
+            if storage_idx is not None:
+                data_to_replace[storage_idx] = data_idx

         # -1 will be interpreted as invalid by prioritized buffers
         out_index = torch.full(data.shape, -1, dtype=torch.long)
         # Replace the data in the storage all at once
         if len(data_to_replace) > 0:
-            keys, values = zip(*data_to_replace.items())
+            storage_idx, data_idx = zip(*data_to_replace.items())
             index = data.get("index", None)
             dtype = index.dtype if index is not None else torch.long
             device = index.device if index is not None else data.device
-            values = list(values)
-            keys = torch.tensor(keys, dtype=dtype, device=device)
-            out_index[values] = keys
-            if index is not None:
-                index[values] = keys
-                data.set("index", index)
-            self._storage.set(keys, data[values])
-            return self._replicate_index(out_index)
-        return out_index
+            data_idx = torch.as_tensor(data_idx, dtype=dtype, device=device)
+            storage_idx = torch.as_tensor(storage_idx, dtype=dtype, device=device)
+            out_index[data_idx] = storage_idx
+            self._storage.set(storage_idx, data[data_idx])
+        return self._replicate_index(out_index)

     def _empty(self) -> None:
         self._cursor = 0
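For reference, the -1 entries in out_index originate here whenever a sample's rank is too low to enter storage. A hedged sketch of the writer in use (rank_key="td_error" and the storage size are illustrative choices, not from the diff):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.writers import TensorDictMaxValueWriter

    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(2),
        writer=TensorDictMaxValueWriter(rank_key="td_error"),
    )
    data = TensorDict(
        {"obs": torch.randn(4, 3), "td_error": torch.rand(4)},
        batch_size=[4],
    )
    idx = rb.extend(data)  # rejected items are reported as index -1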
