[Feature] batched trajectories - SliceSampler compatibility #1775

Merged 34 commits on Feb 22, 2024

Commits (34):
ffcdf51: init (vmoens, Jan 8, 2024)
d567665: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 6, 2024)
a05eaae: amend (vmoens, Feb 6, 2024)
34fadac: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 7, 2024)
fe960fa: amend (vmoens, Feb 7, 2024)
91941ea: amend (vmoens, Feb 7, 2024)
5e06cd4: amend (vmoens, Feb 7, 2024)
2523ab0: amend (vmoens, Feb 8, 2024)
d1e8878: amend (vmoens, Feb 12, 2024)
238e9b3: amend (vmoens, Feb 12, 2024)
eb731a4: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 13, 2024)
cdd5f9a: amend (vmoens, Feb 14, 2024)
775ffb6: amend (vmoens, Feb 14, 2024)
01bcc79: amend (vmoens, Feb 15, 2024)
136ad70: amend (vmoens, Feb 15, 2024)
2fbb08b: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 15, 2024)
2c717c8: amend (vmoens, Feb 15, 2024)
94a9b61: amend (vmoens, Feb 16, 2024)
358a15f: amend (vmoens, Feb 16, 2024)
1495c37: amend (vmoens, Feb 16, 2024)
08e4084: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 20, 2024)
b049c49: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 21, 2024)
253565d: fix (vmoens, Feb 21, 2024)
fb9e562: fix (vmoens, Feb 21, 2024)
b349086: fix (vmoens, Feb 21, 2024)
9094798: fix (vmoens, Feb 21, 2024)
49f32b8: fix (vmoens, Feb 21, 2024)
9aa80ee: fix (vmoens, Feb 21, 2024)
ff535a6: fix (vmoens, Feb 21, 2024)
ee61ed5: Merge remote-tracking branch 'origin/main' into extend-rb-dim1 (vmoens, Feb 21, 2024)
15f3d53: amend (vmoens, Feb 21, 2024)
d7118d0: amend (vmoens, Feb 22, 2024)
9612977: amend (vmoens, Feb 22, 2024)
36ea29f: amend (vmoens, Feb 22, 2024)
108 changes: 106 additions & 2 deletions docs/source/reference/data.rst
@@ -218,24 +218,128 @@ Storing trajectories
~~~~~~~~~~~~~~~~~~~~

It is not too difficult to store trajectories in the replay buffer.
One element to pay attention to is that the size of the replay buffer is by default
the size of the leading dimension of the storage: in other words, creating a
replay buffer with a storage of size 1M when storing multidimensional data
does not mean storing 1M frames but 1M trajectories. However, if trajectories
(or episodes/rollouts) are flattened before being stored, the capacity will still
be 1M steps.
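
As a rough sketch of this accounting (the shapes and capacity below are
illustrative assumptions rather than values used elsewhere in this section),
extending a buffer with a ``[B, T]``-shaped tensordict registers ``B`` entries,
whereas flattening it first registers ``B * T`` entries:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> # a hypothetical batch of 4 trajectories of 5 steps each
>>> data = TensorDict({"obs": torch.randn(4, 5, 3)}, batch_size=[4, 5])
>>> rb = ReplayBuffer(storage=LazyTensorStorage(1000))
>>> rb.extend(data)
>>> print(len(rb))  # 4: one entry per trajectory
>>> rb_flat = ReplayBuffer(storage=LazyTensorStorage(1000))
>>> rb_flat.extend(data.reshape(-1))
>>> print(len(rb_flat))  # 20: one entry per step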

There is a way to circumvent this by telling the storage how many dimensions
it should take into account when saving data. This can be done through the ``ndim``
keyword argument which is accepted by all contiguous storages such as
:class:`~torchrl.data.replay_buffers.TensorStorage` and the like. When a
multidimensional storage is passed to a buffer, the buffer will automatically
consider the last dimension as the "time" dimension, as is conventional in
TorchRL. This can be overridden through the ``dim_extend`` keyword argument
in :class:`~torchrl.data.ReplayBuffer`.
This is the recommended way to save trajectories that are obtained through
:class:`~torchrl.envs.ParallelEnv` or its serial counterpart, as we will see
below.
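
As a minimal sketch of what ``ndim`` changes (again with illustrative,
hypothetical shapes), a two-dimensional storage counts its capacity in
individual steps rather than in trajectories, without requiring the data
to be flattened first:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> # 4 environments (or trajectories), 5 steps each, kept in [B, T] layout
>>> data = TensorDict({"obs": torch.randn(4, 5, 3)}, batch_size=[4, 5])
>>> rb = ReplayBuffer(storage=LazyTensorStorage(1000, ndim=2))
>>> rb.extend(data)
>>> print(len(rb))  # 20: with ndim=2, capacity and length are counted in steps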

When sampling trajectories, it may be desirable to sample sub-trajectories
to diversify learning or make the sampling more efficient.
TorchRL offers two distinct ways of accomplishing this:

- The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` allows sampling
a given number of slices of trajectories stored one after another
along the leading dimension of the :class:`~torchrl.data.replay_buffers.TensorStorage`.
This is the recommended way of sampling sub-trajectories in TorchRL, **especially**
when using offline datasets (which are stored using that convention).
This strategy requires flattening the trajectories before extending the replay
buffer and reshaping them after sampling.
The :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` class docstring
gives extensive details about this storage and sampling strategy.
Note that :class:`~torchrl.data.replay_buffers.samplers.SliceSampler`
is compatible with multidimensional storages. The following examples show
how to use this feature with and without flattening of the tensordict.
In the first scenario, we are collecting data from a single environment. In
that case, we are happy with a storage that concatenates the data coming in
along the first dimension, since there will be no interruption introduced
by the collection schedule:

>>> from torchrl.envs import TransformedEnv, StepCounter, GymEnv
>>> from torchrl.collectors import SyncDataCollector, RandomPolicy
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler
>>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
>>> collector = SyncDataCollector(env,
... RandomPolicy(env.action_spec),
... frames_per_batch=10, total_frames=-1)
>>> rb = ReplayBuffer(
... storage=LazyTensorStorage(100),
... sampler=SliceSampler(num_slices=8, traj_key=("collector", "traj_ids"),
... truncated_key=None, strict_length=False),
... batch_size=64)
>>> for i, data in enumerate(collector):
... rb.extend(data)
... if i == 10:
... break
>>> assert len(rb) == 100, len(rb)
>>> print(rb[:]["next", "step_count"])
tensor([[32],
[33],
[34],
[35],
[36],
[37],
[38],
[39],
[40],
[41],
[11],
[12],
[13],
[14],
[15],
[16],
[17],
[...

If more than one environment is run in a batch, we could still store the data
in the same buffer as before by calling ``data.reshape(-1)``, which flattens
the ``[B, T]`` shape into ``[B * T]``. However, that means that the
trajectories of, say, the first environment of the batch will be interleaved
with trajectories of the other environments, a scenario that ``SliceSampler``
cannot handle. To solve this, we suggest using the ``ndim`` argument in the
storage constructor:

>>> from torchrl.envs import SerialEnv
>>> env = TransformedEnv(SerialEnv(2,
... lambda: GymEnv("CartPole-v1")), StepCounter())
>>> collector = SyncDataCollector(env,
... RandomPolicy(env.action_spec),
... frames_per_batch=1, total_frames=-1)
>>> rb = ReplayBuffer(
... storage=LazyTensorStorage(100, ndim=2),
... sampler=SliceSampler(num_slices=8, traj_key=("collector", "traj_ids"),
... truncated_key=None, strict_length=False),
... batch_size=64)
>>> for i, data in enumerate(collector):
... rb.extend(data)
... if i == 100:
... break
>>> assert len(rb) == 100, len(rb)
>>> print(rb[:]["next", "step_count"].squeeze())
tensor([[ 6, 5],
[ 2, 2],
[ 3, 3],
[ 4, 4],
[ 5, 5],
[ 6, 6],
[ 7, 7],
[ 8, 8],
[ 9, 9],
[10, 10],
[11, 11],
[12, 12],
[13, 13],
[14, 14],
[15, 15],
[16, 16],
[17, 17],
[18, 1],
[19, 2],
[...


- Trajectories can also be stored independently, with each element of the
leading dimension pointing to a different trajectory. This requires