
[Feature Request] Multi-dim RBs #1885

Closed
@vmoens


I'm looking for feedback on a feature that I'd like to ship in v0.4.

Related PR: #1775

Pitch

Currently, we assume that replay buffers (RBs) are built as 1d datasets from which you can sample items randomly or as slices.

Now imagine that you have a ParallelEnv (n=4 workers) within a data collector that spits out slices of trajectories (len=100), i.e. batches of shape [4, 100], and you would like to store them all together in the buffer. Assume the envs never terminate (each env ends up producing one gigantic trajectory).

Currently you could:

  • flatten the batch into a tensordict of shape [400] and extend the buffer with that;
  • store it as 4 separate elements of shape [100], in a storage of shape [M, 100].
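To make the two options concrete, here is a toy sketch in plain Python. It is an assumption for illustration only: lists stand in for tensordicts and for the storage, and each transition is just an `(env, step)` tuple; no torchrl API is used.

```python
# Toy stand-in for one collector batch: 4 envs x 100 steps, i.e. shape [4, 100].
# Each transition is an (env_id, step) tuple instead of a tensordict.
N_ENVS, T = 4, 100
batch = [[(env, t) for t in range(T)] for env in range(N_ENVS)]

# Option 1: flatten to shape [400] and extend a 1d buffer with it.
flat_buffer = []
flat_buffer.extend(tr for env_traj in batch for tr in env_traj)

# Option 2: store 4 elements, each a [100] sub-trajectory (storage shape [M, 100]).
traj_buffer = []
traj_buffer.extend(batch)

print(len(flat_buffer), len(traj_buffer))  # 400 4
```

In the flat layout, env 1's steps start right after env 0's 100 steps, so trajectory boundaries have to be tracked separately; in the second layout each element already is a trajectory chunk.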

In the first case, you can use a regular sampler and get single transitions, or use a slice sampler and get sub-trajectories of length < 100.

In the second case, a regular sampler always returns trajectories of length 100, unless you use a RandomCropTransform to extract random slices from them. Those slices always lie within the 100-step window that you stored.
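The limitation of random cropping can be shown with a small sketch. `random_crop` is a hypothetical helper that mimics what a crop transform does on one stored element; it is not the torchrl implementation.

```python
import random

def random_crop(traj, length, rng=random):
    """Return a random contiguous window of `length` steps from `traj`.

    Hypothetical helper: the window can never extend past the stored
    trajectory, so crops are bounded by the stored element's length.
    """
    if length > len(traj):
        raise ValueError("crop length exceeds stored trajectory length")
    start = rng.randrange(len(traj) - length + 1)
    return traj[start:start + length]

traj = list(range(100))  # one stored sub-trajectory of length 100
crop = random_crop(traj, 30, random.Random(0))
assert len(crop) == 30
assert crop == list(range(crop[0], crop[0] + 30))  # contiguous window
```

Asking for a 150-step crop from this element fails, even though the env that produced it kept running.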

Another difference between these two approaches is the length of the buffer: if you want to store 1M transitions in the first case, you just create a storage of size 1M. In the second, you divide 1M by 100 and store 10K sub-trajectories of length 100.

Both currently work with torchrl, but I don't think either is what people really want. The best would be to store the results from the 4 envs together, hence have a buffer of size [4, 250K] (or [250K, 4]).
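The same 1M transitions therefore lead to three different "capacities" depending on the layout. A quick arithmetic check, with the time-major [250K, 4] variant chosen for the multi-dim case:

```python
TOTAL_TRANSITIONS = 1_000_000
N_ENVS, SLICE_LEN = 4, 100

# Layout 1: flat [N] storage, one entry per transition.
flat_capacity = TOTAL_TRANSITIONS
# Layout 2: [M, 100] storage, one entry per 100-step sub-trajectory.
traj_capacity = TOTAL_TRANSITIONS // SLICE_LEN
# Multi-dim layout: [250K, 4] storage, one entry per time step across all envs.
multi_dim_shape = (TOTAL_TRANSITIONS // N_ENVS, N_ENVS)

print(flat_capacity, traj_capacity, multi_dim_shape)  # 1000000 10000 (250000, 4)
```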

See this comment in particular: #1079 (comment)

Our API is heavily built on the assumption that time is along the first dimension, so my initial idea was simply to transpose the time dimension and the first dimension of the tensordict when populating the buffer (see the PR description for how that works). That works fine and users won't notice, so far so good.
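The transposition can be sketched with a toy circular storage whose first dim is time. `TimeMajorStorage` is hypothetical and only tracks nested lists; it illustrates the idea rather than the code in the PR.

```python
def transpose(batch):
    """Swap the env dim and the time dim: [n_envs, T] -> [T, n_envs]."""
    return [list(step) for step in zip(*batch)]

class TimeMajorStorage:
    """Hypothetical circular storage of shape [capacity, n_envs], time along dim 0."""

    def __init__(self, capacity, n_envs):
        self.capacity, self.n_envs = capacity, n_envs
        self.rows = [None] * capacity
        self.cursor = 0  # next row to write
        self.size = 0    # number of rows filled so far

    def extend(self, batch):
        # The collector yields [n_envs, T]; store it time-major, one row per step.
        for row in transpose(batch):
            self.rows[self.cursor] = row
            self.cursor = (self.cursor + 1) % self.capacity
            self.size = min(self.size + 1, self.capacity)

batch = [[(env, t) for t in range(100)] for env in range(4)]
storage = TimeMajorStorage(capacity=250, n_envs=4)
storage.extend(batch)
print(storage.rows[0])  # [(0, 0), (1, 0), (2, 0), (3, 0)]
```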

The tricky part is making SliceSampler work, but in the PR I managed to get the base class working (I will look into the version without repetitions and the prioritized version soon).
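Conceptually, slice sampling on a time-major [T, n_envs] storage means picking an env column and a start step, then reading consecutive rows. `sample_slices` below is a hypothetical sketch, not torchrl's SliceSampler. It assumes non-terminating envs (each column is one trajectory) and ignores circular wrap-around.

```python
import random

def sample_slices(rows, n_slices, slice_len, rng=random):
    """Sample `n_slices` sub-trajectories of `slice_len` steps from a
    time-major storage `rows` of shape [T, n_envs].

    Hypothetical sketch: each env column is treated as one trajectory,
    and wrap-around of a circular buffer is not handled.
    """
    T, n_envs = len(rows), len(rows[0])
    out = []
    for _ in range(n_slices):
        env = rng.randrange(n_envs)
        start = rng.randrange(T - slice_len + 1)
        out.append([rows[t][env] for t in range(start, start + slice_len)])
    return out

rows = [[(env, t) for env in range(4)] for t in range(250)]  # shape [250, 4]
slices = sample_slices(rows, n_slices=8, slice_len=150, rng=random.Random(0))
```

Unlike the [M, 100] layout, slices here are not capped by the 100-step window of a single collector batch.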

Asking for opinions

(1) That's an awful lot of ways of doing the same thing. Should we clean this up? Do we really need 3 ways of doing it?

(2) The one thing I can't get my head around is the capacity of the buffer. First, we need a clear message about what we mean by capacity. For instance, with #1775 we build a buffer of capacity 250K (one entry per time step, each holding the transitions of the 4 envs together), but when we indicate the batch size with the slice sampler, what we ask for is a number of transitions, regardless of the number of envs. That discrepancy isn't good: the ideal solution would be a buffer capacity of 1M, which aligns capacity and batch size (both would then measure single-env transitions).
Fair enough, you might say, but in the second solution above, where we had a buffer of capacity 10K, we were dividing 1M by the number of time steps! That's all very confusing.
You may ask: why this odd choice of ignoring the 2nd or 3rd dim when counting the elements of a tensordict? Why is the capacity restricted to the first dim? The answer is simply that we want interoperability between the RB and its storage, whether it's a contiguous storage or a list. With a list, we unbind along the first dimension, so only the first dim counts. That restriction is what forced us to count only the first dim when measuring buffer capacity.
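The list-storage constraint can be illustrated with a toy class. `ToyListStorage` is hypothetical and uses nested lists in place of tensordicts; unbinding along the first dim is modeled by iterating over the outer list.

```python
class ToyListStorage:
    """Hypothetical list-backed storage: one list entry per first-dim element."""

    def __init__(self):
        self._items = []

    def extend(self, batch):
        # Unbind along the first dim: each outer element becomes one entry,
        # whatever its own inner shape is.
        self._items.extend(batch)

    def __len__(self):
        return len(self._items)

storage = ToyListStorage()
storage.extend([[0] * 5 for _ in range(4)])  # batch of shape [4, 5]
print(len(storage))  # 4, not 20: only the first dim is counted
```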

Should we unify all of this (at the risk of introducing BC-breaking changes) and always measure buffer capacity in terms of transitions per env in a batch?
If so, here is how I would approach this issue:

All storages will have an ndim attribute. At first, ndim will default to 1 (your buffer is 1d and capacity is measured along the first dim only).
Examples:

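from tensordict import TensorDict
from torchrl.data import ReplayBuffer

# SomeStorage is a placeholder for any storage class accepting the proposed ndim argument
storage = SomeStorage(100, ndim=1)
rb = ReplayBuffer(storage=storage)
rb.extend(TensorDict({}, batch_size=[4]))
assert len(rb) == 4
# reset rb here
rb.extend(TensorDict({}, batch_size=[4, 5]))
assert len(rb) == 4  # ndim=1: only the first dim is counted

storage = SomeStorage(100, ndim=2)
rb = ReplayBuffer(storage=storage)
rb.extend(TensorDict({}, batch_size=[4]))  # Error! fewer batch dims than storage.ndim
# reset rb here
rb.extend(TensorDict({}, batch_size=[4, 5]))
assert len(rb) == 20  # ndim=2: 4 * 5 elements are counted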

This can work for ListStorage too (we just need to check the type of the data, and if it's a pytree or a tensordict, grab the dims from it).
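The counting rule proposed above can be checked with a shape-only sketch. `NdimStorageSketch` is hypothetical: it tracks only batch shapes (a tuple standing in for a tensordict's `batch_size`), does not store data, and does not model how a ListStorage would read dims from a pytree.

```python
from math import prod

class NdimStorageSketch:
    """Hypothetical storage counting elements over its first `ndim` batch dims."""

    def __init__(self, max_size, ndim=1):
        self.max_size, self.ndim = max_size, ndim
        self._len = 0

    def extend(self, batch_shape):
        if len(batch_shape) < self.ndim:
            raise ValueError(
                f"expected at least {self.ndim} batch dims, got {len(batch_shape)}"
            )
        # len() grows by the product of the first `ndim` dims, capped at max_size
        self._len = min(self._len + prod(batch_shape[: self.ndim]), self.max_size)

    def __len__(self):
        return self._len

s1 = NdimStorageSketch(100, ndim=1)
s1.extend((4, 5))
print(len(s1))  # 4

s2 = NdimStorageSketch(100, ndim=2)
s2.extend((4, 5))
print(len(s2))  # 20
```

Capping at `max_size` mirrors a circular buffer that overwrites old entries once full; whether the cap is expressed in the product of the first `ndim` dims is exactly the capacity question raised in (2).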
