Skip to content

Commit

Permalink
Add PrioritizedSliceSampler + few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Feb 6, 2024
1 parent 5f82601 commit ac4fa9d
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 15 deletions.
48 changes: 34 additions & 14 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from torchrl.data.replay_buffers import samplers, writers
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
PrioritizedSliceSampler,
RandomSampler,
SamplerEnsemble,
SamplerWithoutReplacement,
Expand Down Expand Up @@ -1834,13 +1835,14 @@ def test_sampler_without_rep_state_dict(self, backend):
assert (s.exclude("index") == 0).all()

@pytest.mark.parametrize(
"batch_size,num_slices,slice_len",
"batch_size,num_slices,slice_len,prioritized",
[
[100, 20, None],
[120, 30, None],
[100, None, 5],
[120, None, 4],
[101, None, 101],
[100, 20, None, True],
[100, 20, None, False],
[120, 30, None, False],
[100, None, 5, False],
[120, None, 4, False],
[101, None, 101, False],
],
)
@pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")])
Expand All @@ -1853,6 +1855,7 @@ def test_slice_sampler(
batch_size,
num_slices,
slice_len,
prioritized,
episode_key,
done_key,
match_episode,
Expand Down Expand Up @@ -1897,19 +1900,34 @@ def test_slice_sampler(
else:
strict_length = True

sampler = SliceSampler(
num_slices=num_slices,
traj_key=episode_key,
end_key=done_key,
slice_len=slice_len,
strict_length=strict_length,
)
if prioritized:
num_steps = data.shape[0]
sampler = PrioritizedSliceSampler(
max_capacity=num_steps,
alpha=0.7,
beta=0.9,
num_slices=num_slices,
traj_key=episode_key,
end_key=done_key,
slice_len=slice_len,
strict_length=strict_length,
)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
else:
sampler = SliceSampler(
num_slices=num_slices,
traj_key=episode_key,
end_key=done_key,
slice_len=slice_len,
strict_length=strict_length,
)
if slice_len is not None:
num_slices = batch_size // slice_len
trajs_unique_id = set()
too_short = False
count_unique = set()
for _ in range(10):
for _ in range(30):
index, info = sampler.sample(storage, batch_size=batch_size)
if _data_prefix:
samples = storage._storage["_data"][index]
Expand All @@ -1918,6 +1936,7 @@ def test_slice_sampler(
if strict_length:
# check that trajs are ok
samples = samples.view(num_slices, -1)

assert samples["another_episode"].unique(
dim=1
).squeeze().shape == torch.Size([num_slices])
Expand All @@ -1936,6 +1955,7 @@ def test_slice_sampler(
raise AssertionError(
f"Not all items can be sampled: {set(range(100))-count_unique} are missing"
)

if strict_length:
assert not too_short
else:
Expand Down
152 changes: 151 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def _sample_slices(
truncated[seq_length.cumsum(0) - 1] = 1
traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1
terminated = torch.zeros_like(truncated)
if terminated.any():
if traj_terminated.any():
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[traj_terminated] = 1
else:
Expand Down Expand Up @@ -1079,6 +1079,156 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
return SamplerWithoutReplacement.load_state_dict(self, state_dict)


class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
    """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.

    The start index of every slice is drawn by :class:`PrioritizedSampler`;
    the slice is then expanded to ``seq_length`` consecutive indices using the
    :class:`SliceSampler` helpers. Indices that are too close to the end of a
    trajectory to start a full slice have their sum-tree weight forced to
    zero before drawing.

    For more info see: SliceSampler and PrioritizedSampler

    Args:
        max_capacity (int): forwarded to ``PrioritizedSampler.__init__``.
        alpha (float): forwarded to ``PrioritizedSampler.__init__``.
        beta (float): forwarded to ``PrioritizedSampler.__init__``.
        eps (float, optional): forwarded to ``PrioritizedSampler.__init__``.
            Defaults to ``1e-8``.
        dtype (torch.dtype, optional): forwarded to
            ``PrioritizedSampler.__init__``. Defaults to ``torch.float``.
        reduction (str, optional): forwarded to
            ``PrioritizedSampler.__init__``. Defaults to ``"max"``.

    Keyword Args:
        num_slices (int, optional): forwarded to ``SliceSampler.__init__``.
        slice_len (int, optional): forwarded to ``SliceSampler.__init__``.
        end_key (NestedKey, optional): forwarded to ``SliceSampler.__init__``.
        traj_key (NestedKey, optional): forwarded to ``SliceSampler.__init__``.
        ends (torch.Tensor, optional): forwarded to ``SliceSampler.__init__``.
        trajectories (torch.Tensor, optional): forwarded to
            ``SliceSampler.__init__``.
        cache_values (bool, optional): forwarded to ``SliceSampler.__init__``.
            Defaults to ``False``.
        truncated_key (NestedKey, optional): key under which the truncated
            signal is written in the info dict returned by :meth:`sample`.
            The ``"done"`` and ``"terminated"`` keys are derived from it by
            replacing its last element. If ``None``, no such entries are
            written. Defaults to ``("next", "truncated")``.
        strict_length (bool, optional): if ``True``, :meth:`sample` raises a
            ``RuntimeError`` when a stored trajectory is shorter than the
            requested slice length. Defaults to ``True``.
    """

    def __init__(
        self,
        max_capacity: int,
        alpha: float,
        beta: float,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        reduction: str = "max",
        *,
        num_slices: int | None = None,
        slice_len: int | None = None,
        end_key: NestedKey | None = None,
        traj_key: NestedKey | None = None,
        ends: torch.Tensor | None = None,
        trajectories: torch.Tensor | None = None,
        cache_values: bool = False,
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
    ) -> None:
        # Both parents are initialised explicitly (rather than through
        # ``super()``) because they take disjoint sets of arguments.
        SliceSampler.__init__(
            self,
            num_slices=num_slices,
            slice_len=slice_len,
            end_key=end_key,
            traj_key=traj_key,
            cache_values=cache_values,
            truncated_key=truncated_key,
            strict_length=strict_length,
            ends=ends,
            trajectories=trajectories,
        )
        PrioritizedSampler.__init__(
            self,
            max_capacity=max_capacity,
            alpha=alpha,
            beta=beta,
            eps=eps,
            dtype=dtype,
            reduction=reduction,
        )

    def __getstate__(self):
        """Return the picklable state, merging the states of both parents.

        The ``PrioritizedSampler`` entries take precedence over the
        ``SliceSampler`` ones when a key is present in both.
        """
        state = SliceSampler.__getstate__(self)
        state.update(PrioritizedSampler.__getstate__(self))
        # The merged state must be returned, otherwise ``pickle``/``copy``
        # receive ``None`` and the sampler state is silently dropped.
        return state

    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
        """Sample ``batch_size`` indices organised as contiguous slices.

        Args:
            storage (Storage): the storage the trajectories are read from.
            batch_size (int): total number of indices to return.

        Returns:
            A tuple ``(index, info)`` where ``index`` is a ``torch.long``
            tensor whose first dimension is ``batch_size`` and ``info`` is
            the dict returned by ``PrioritizedSampler.sample``, extended with
            the truncated / done / terminated entries when
            ``self.truncated_key`` is not ``None``.

        Raises:
            RuntimeError: if ``strict_length`` is set and some stored
                trajectory is shorter than the requested slice length.
            NotImplementedError: if ``strict_length`` is not set and some
                stored trajectory is shorter than the requested slice length.
        """
        # Sample `batch_size` indices representing the start of a slice.
        # The sampling is based on a weight vector.
        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
        seq_length, num_slices = self._adjusted_batch_size(batch_size)

        num_trajs = lengths.shape[0]
        traj_idx = torch.arange(0, num_trajs, 1, device=lengths.device)

        if (lengths < seq_length).any():
            if self.strict_length:
                raise RuntimeError(
                    "Some stored trajectories have a length shorter than the slice that was asked for. "
                    "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
                    "in you batch."
                )
            # make seq_length a tensor with values clamped by lengths
            seq_length = lengths[traj_idx].clamp_max(seq_length)

        # Build the list of indices that must not be sampled as a slice
        # start: the last `seq_length - 1` steps of every trajectory, end of
        # trajectory (`stop_idx`) included. Starting there would make the
        # slice run past the end of the trajectory.
        if isinstance(seq_length, int):
            subtractive_idx = torch.arange(
                0, seq_length - 1, 1, device=stop_idx.device, dtype=stop_idx.dtype
            )
            preceding_stop_idx = (
                stop_idx[..., None] - subtractive_idx[None, ...]
            ).view(-1)
        else:
            # TODO: support per-trajectory slice lengths by concatenating,
            # for every trajectory, `stop_idx - arange(seq_len)`.
            raise NotImplementedError(
                "PrioritizedSliceSampler does not support trajectories shorter "
                "than the slice length (strict_length=False) yet."
            )

        # force to not sample index at the end of a trajectory.
        # it's ok to not touch self._min_tree.
        self._sum_tree[preceding_stop_idx] = 0.0

        starts, info = PrioritizedSampler.sample(
            self, storage=storage, batch_size=batch_size // seq_length
        )
        starts = torch.from_numpy(starts).to(device=lengths.device)
        index = self._tensor_slices_from_startend(seq_length, starts)
        assert index.shape[0] == batch_size

        if self.truncated_key is not None:
            # following logics borrowed from SliceSampler
            truncated_key = self.truncated_key
            done_key = _replace_last(truncated_key, "done")
            terminated_key = _replace_last(truncated_key, "terminated")

            truncated = torch.zeros(
                (*index.shape, 1), dtype=torch.bool, device=index.device
            )
            # The last step of every slice is flagged as truncated.
            if isinstance(seq_length, int):
                truncated.view(num_slices, -1)[:, -1] = 1
            else:
                truncated[seq_length.cumsum(0) - 1] = 1
            # NOTE(review): `traj_terminated` is indexed per stored
            # trajectory (via `traj_idx`), not per sampled slice, and
            # `terminated` is never written to below, so it stays all-False.
            # Kept as-is to match the borrowed logic -- confirm the intended
            # semantics before relying on the terminated/done entries.
            traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1
            terminated = torch.zeros_like(truncated)
            if traj_terminated.any():
                if isinstance(seq_length, int):
                    truncated.view(num_slices, -1)[traj_terminated] = 1
                else:
                    truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
            truncated = truncated & ~terminated
            done = terminated | truncated

            info.update(
                {
                    truncated_key: truncated,
                    done_key: done,
                    terminated_key: terminated,
                }
            )
        return index.to(torch.long), info

    def _empty(self):
        """Delegate to ``PrioritizedSampler._empty``."""
        # no op for SliceSampler
        PrioritizedSampler._empty(self)

    def dumps(self, path):
        """Delegate to ``PrioritizedSampler.dumps`` with ``path``."""
        # no op for SliceSampler
        PrioritizedSampler.dumps(self, path)

    def loads(self, path):
        """Delegate to ``PrioritizedSampler.loads`` and return its result."""
        # no op for SliceSampler
        return PrioritizedSampler.loads(self, path)

    def state_dict(self):
        """Return the state dict produced by ``PrioritizedSampler.state_dict``."""
        # no op for SliceSampler
        return PrioritizedSampler.state_dict(self)


class SamplerEnsemble(Sampler):
"""An ensemble of samplers.
Expand Down

0 comments on commit ac4fa9d

Please sign in to comment.