
Commit

fix
vmoens committed Feb 21, 2024
1 parent 253565d commit fb9e562
Showing 5 changed files with 104 additions and 34 deletions.
63 changes: 51 additions & 12 deletions test/test_rb.py
@@ -31,7 +31,7 @@
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map

from torchrl.collectors import SyncDataCollector, RandomPolicy
from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.data import (
PrioritizedReplayBuffer,
RemoteTensorDictReplayBuffer,
@@ -833,6 +833,8 @@ def test_extend_list_pytree(self, max_size, shape, storage):
for i in range(10)
]
memory.extend(data)
assert len(memory) == 10
assert len(memory._storage) == 10
sample = memory.sample(10)
for leaf in torch.utils._pytree.tree_leaves(sample):
assert (leaf.unique(sorted=True) == torch.arange(10)).all()
@@ -2565,25 +2567,62 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
assert leaf.shape[0] == 4
assert (leaf == 1).all()

@pytest.mark.parametrize("writer_cls", [TensorDictMaxValueWriter, RoundRobinWriter, TensorDictRoundRobinWriter])
@pytest.mark.parametrize(
"writer_cls",
[TensorDictMaxValueWriter, RoundRobinWriter, TensorDictRoundRobinWriter],
)
@pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("rbtype", [functools.partial(ReplayBuffer, batch_size=8), functools.partial(TensorDictReplayBuffer, batch_size=8)])
@pytest.mark.parametrize("sampler_cls", [functools.partial(SliceSampler, num_slices=2, strict_length=False),
RandomSampler,
functools.partial(SliceSamplerWithoutReplacement, num_slices=2, strict_length=False),
functools.partial(PrioritizedSampler, alpha=1.0, beta=1.0, max_capacity=10),
functools.partial(PrioritizedSliceSampler, alpha=1.0, beta=1.0, max_capacity=10, num_slices=2, strict_length=False)])
@pytest.mark.parametrize(
"rbtype",
[
functools.partial(ReplayBuffer, batch_size=8),
functools.partial(TensorDictReplayBuffer, batch_size=8),
],
)
@pytest.mark.parametrize(
"sampler_cls",
[
functools.partial(SliceSampler, num_slices=2, strict_length=False),
RandomSampler,
functools.partial(
SliceSamplerWithoutReplacement, num_slices=2, strict_length=False
),
functools.partial(PrioritizedSampler, alpha=1.0, beta=1.0, max_capacity=10),
functools.partial(
PrioritizedSliceSampler,
alpha=1.0,
beta=1.0,
max_capacity=10,
num_slices=2,
strict_length=False,
),
],
)
def test_rb_multidim_collector(self, rbtype, storage_cls, writer_cls, sampler_cls):
from _utils_internal import CARTPOLE_VERSIONED

torch.manual_seed(0)
env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED))
env.set_seed(0)
collector = SyncDataCollector(env, RandomPolicy(env.action_spec), frames_per_batch=4, total_frames=16)
collector = SyncDataCollector(
env, RandomPolicy(env.action_spec), frames_per_batch=4, total_frames=16
)
if writer_cls is TensorDictMaxValueWriter:
with pytest.raises(ValueError, match="TensorDictMaxValueWriter is not compatible with storages with more than one dimension"):
rb = rbtype(storage=storage_cls(max_size=10, ndim=2), sampler=sampler_cls(), writer=writer_cls())
with pytest.raises(
ValueError,
match="TensorDictMaxValueWriter is not compatible with storages with more than one dimension",
):
rb = rbtype(
storage=storage_cls(max_size=10, ndim=2),
sampler=sampler_cls(),
writer=writer_cls(),
)
return
rb = rbtype(storage=storage_cls(max_size=10, ndim=2), sampler=sampler_cls(), writer=writer_cls())
rb = rbtype(
storage=storage_cls(max_size=10, ndim=2),
sampler=sampler_cls(),
writer=writer_cls(),
)
for data in collector:
rb.extend(data)
rb.sample()
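
Note (illustrative, not part of the commit): the test_rb_multidim_collector case above exercises a replay buffer whose storage keeps more than one batch dimension (ndim=2) and is fed by a collector. The sketch below mirrors that pattern; the concrete "CartPole-v1" id and the top-level torchrl.data / torchrl.envs import paths are assumptions, since the test resolves the environment id through the CARTPOLE_VERSIONED helper.

# Rough sketch of the pattern exercised by test_rb_multidim_collector.
import torch
from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SliceSampler
from torchrl.envs import GymEnv, SerialEnv

torch.manual_seed(0)
env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))  # 2 workers -> leading batch dim of 2
collector = SyncDataCollector(
    env, RandomPolicy(env.action_spec), frames_per_batch=4, total_frames=16
)
rb = ReplayBuffer(
    # ndim=2 keeps the (env, time) layout instead of flattening into one dimension
    storage=LazyTensorStorage(max_size=10, ndim=2),
    sampler=SliceSampler(num_slices=2, strict_length=False),
    batch_size=8,
)
for data in collector:  # each batch has batch_size [2, 2]: 2 envs x 2 frames
    rb.extend(data)
    sample = rb.sample()  # slices of consecutive steps; strict_length=False allows short ones
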
8 changes: 4 additions & 4 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -922,7 +922,7 @@ def _get_priority_item(self, tensordict: TensorDictBase) -> float:
if self._storage.ndim > 1:
# We have to flatten the priority otherwise we'll be aggregating
# the priority across batches
priority = priority.flatten(0, self._storage.ndim-1)
priority = priority.flatten(0, self._storage.ndim - 1)
if priority is None:
return self._sampler.default_priority
try:
@@ -938,7 +938,7 @@ def _get_priority_item(self, tensordict: TensorDictBase) -> float:
)

if self._storage.ndim > 1:
priority = priority.unflatten(0, tensordict.shape[:self._storage.ndim])
priority = priority.unflatten(0, tensordict.shape[: self._storage.ndim])

return priority

@@ -953,13 +953,13 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
if self._storage.ndim > 1:
# We have to flatten the priority otherwise we'll be aggregating
# the priority across batches
priority = priority.flatten(0, self._storage.ndim-1)
priority = priority.flatten(0, self._storage.ndim - 1)

priority = priority.reshape(priority.shape[0], -1)
priority = _reduce(priority, self._sampler.reduction, dim=1)

if self._storage.ndim > 1:
priority = priority.unflatten(0, tensordict.shape[:self._storage.ndim])
priority = priority.unflatten(0, tensordict.shape[: self._storage.ndim])

return priority

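
Note (illustrative, not part of the commit): beyond the spacing fixes, this block is where priorities for multi-dimensional storages are handled: the leading storage dims are flattened so that each stored element carries one priority, trailing dims are reduced, and the result is unflattened back to the tensordict's batch shape. A plain-torch walk-through of those shapes, with max standing in for _reduce and made-up sizes:

# Shape walk-through only; torch is the sole dependency.
import torch

ndim = 2                                # storage.ndim
priority = torch.rand(3, 5, 4)          # per-item priorities over a [3, 5] batch
flat = priority.flatten(0, ndim - 1)    # [15, 4]: one row per stored element
flat = flat.reshape(flat.shape[0], -1)  # collapse trailing dims
reduced = flat.max(dim=1).values        # stands in for _reduce(priority, reduction, dim=1)
restored = reduced.unflatten(0, priority.shape[:ndim])
print(restored.shape)                   # torch.Size([3, 5])
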
23 changes: 12 additions & 11 deletions torchrl/data/replay_buffers/samplers.py
@@ -1337,22 +1337,23 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
try:
seq_length = seq_length.unique().item()
except RuntimeError:
raise NotImplementedError(f"seq_length as a list is not supported for now. seq_length={seq_length}.")
print("start_idx, stop_idx, lengths", start_idx, stop_idx, lengths)
raise NotImplementedError(
f"seq_length as a list is not supported for now. seq_length={seq_length}."
)

subtractive_idx = torch.arange(
0, seq_length - 1, 1, device=stop_idx.device, dtype=stop_idx.dtype
)
preceding_stop_idx = (
stop_idx[..., 0, None] - subtractive_idx[None, ...]
)
print("preceding_stop_idx", preceding_stop_idx)
preceding_stop_idx = stop_idx[..., 0, None] - subtractive_idx[None, ...]
preceding_stop_idx = preceding_stop_idx.reshape(-1, 1)
preceding_stop_idx = torch.cat([preceding_stop_idx, stop_idx[:, 1:]], -1)
if storage.ndim > 1:
# convert the 2d index into a flat one to accommodate the _sum_tree
preceding_stop_idx = torch.as_tensor(np.ravel_multi_index(tuple(preceding_stop_idx.transpose(0, 1).numpy()), storage.shape))
print("preceding_stop_idx after", preceding_stop_idx)
preceding_stop_idx = torch.as_tensor(
np.ravel_multi_index(
tuple(preceding_stop_idx.transpose(0, 1).numpy()), storage.shape
)
)

# force to not sample index at the end of a trajectory
self._sum_tree[preceding_stop_idx] = 0.0
@@ -1361,8 +1362,6 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
starts, info = PrioritizedSampler.sample(
self, storage=storage, batch_size=batch_size // seq_length
)
print("starts", starts)
print("info", info)
if isinstance(starts, tuple):
starts = torch.stack(starts, -1)
# starts = torch.as_tensor(starts, device=lengths.device)
@@ -1393,7 +1392,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
traj_terminated = (stop_idx[traj_idx] == start_idx[traj_idx]).all(-1) + seq_length - 1
traj_terminated = (
(stop_idx[traj_idx] == start_idx[traj_idx]).all(-1) + seq_length - 1
)
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
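
Note (illustrative, not part of the commit): apart from removing the debug prints, the sampler keeps converting 2D (trajectory, time) stop indices into the flat index space used by its sum tree. A small sketch of that np.ravel_multi_index call, with made-up shapes and indices:

import numpy as np
import torch

storage_shape = (2, 10)                              # e.g. 2 trajectories x 10 steps
preceding_stop_idx = torch.tensor([[0, 7], [1, 3]])  # (traj, t) pairs to exclude
flat_idx = np.ravel_multi_index(
    tuple(preceding_stop_idx.transpose(0, 1).numpy()), storage_shape
)
print(flat_idx)  # [ 7 13]: (0, 7) -> 7 and (1, 3) -> 13 in row-major order
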
38 changes: 33 additions & 5 deletions torchrl/data/replay_buffers/storages.py
@@ -547,11 +547,22 @@ def _rand_given_ndim(self, batch_size):

def flatten(self):
if self.ndim == 1:
return self.get(slice(None))
return self
if is_tensor_collection(self._storage):
return self._storage[: self._len_along_dim0].flatten(0, self.ndim-1)
return tree_map(
lambda x: x[: self._len_along_dim0].flatten(0, self.ndim-1), self._storage
if self._is_full:
return TensorStorage(self._storage.flatten(0, self.ndim - 1))
return TensorStorage(
self._storage[: self._len_along_dim0].flatten(0, self.ndim - 1)
)
if self._is_full:
return TensorStorage(
tree_map(lambda x: x.flatten(0, self.ndim - 1), self._storage)
)
return TensorStorage(
tree_map(
lambda x: x[: self._len_along_dim0].flatten(0, self.ndim - 1),
self._storage,
)
)

def __getstate__(self):
@@ -671,7 +682,6 @@ def set(
cursor: Union[int, Sequence[int], slice],
data: Union[TensorDictBase, torch.Tensor],
):
self._get_new_len(data, cursor)

if isinstance(data, list):
# flip list
@@ -688,6 +698,8 @@
f"for per-item addition."
)

self._get_new_len(data, cursor)

if not self.initialized:
if not isinstance(cursor, INT_CLASSES):
if is_tensor_collection(data):
@@ -707,6 +719,22 @@ def set( # noqa: F811
cursor: Union[int, Sequence[int], slice],
data: Union[TensorDictBase, torch.Tensor],
):

if isinstance(data, list):
# flip list
try:
data = _flip_list(data)
except Exception:
raise RuntimeError(
"Stacking the elements of the list resulted in "
"an error. "
f"Storages of type {type(self)} expect all elements of the list "
f"to have the same tree structure. If the list is compact (each "
f"leaf is itself a batch with the appropriate number of elements) "
f"consider using a tuple instead, as lists are used within `extend` "
f"for per-item addition."
)

self._get_new_len(data, cursor)

if not is_tensor_collection(data) and not isinstance(data, torch.Tensor):
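
Note (illustrative, not part of the commit): flatten() now wraps its result in a TensorStorage instead of returning raw tensors (returning the storage itself when ndim == 1, and branching on _is_full), and both set() overloads now stack list inputs before the _get_new_len bookkeeping. The sketch below shows what the flattening lambda does to the leaves of a pytree storage; the dict and shapes are made up:

# Effect of flatten(0, ndim - 1) on each leaf of a ndim=2 pytree storage.
import torch
from torch.utils._pytree import tree_map

ndim = 2
leaves = {"obs": torch.randn(2, 5, 3), "reward": torch.randn(2, 5, 1)}
flat = tree_map(lambda x: x.flatten(0, ndim - 1), leaves)
print(flat["obs"].shape, flat["reward"].shape)  # torch.Size([10, 3]) torch.Size([10, 1])
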
6 changes: 4 additions & 2 deletions torchrl/data/replay_buffers/writers.py
@@ -372,8 +372,10 @@ def __init__(self, rank_key=None, reduction: str = "sum", **kwargs) -> None:

def register_storage(self, storage: Storage) -> None:
if storage.ndim > 1:
raise ValueError("TensorDictMaxValueWriter is not compatible with storages with more than one dimension. "
"See the docstring constructor note about storing trajectories with TensorDictMaxValueWriter.")
raise ValueError(
"TensorDictMaxValueWriter is not compatible with storages with more than one dimension. "
"See the docstring constructor note about storing trajectories with TensorDictMaxValueWriter."
)
return super().register_storage(storage)

def get_insert_index(self, data: Any) -> int:
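
Note (illustrative, not part of the commit): the reformatted guard above is what the test_rb_multidim_collector case in test_rb.py expects: pairing TensorDictMaxValueWriter with a storage whose ndim is greater than 1 fails at construction time. A minimal sketch, with import paths assumed:

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.writers import TensorDictMaxValueWriter

try:
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(max_size=10, ndim=2),
        writer=TensorDictMaxValueWriter(),
        batch_size=8,
    )
except ValueError as err:
    print(err)  # "TensorDictMaxValueWriter is not compatible with storages ..."
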
