
[BUG] Saving episodes with variable steps in the replay buffer #1671

Closed
truncs opened this issue Nov 3, 2023 · 8 comments
Labels
bug Something isn't working

Comments

@truncs

truncs commented Nov 3, 2023

In an env where the episode terminates early, I would still like to save the (shorter) episode in the replay buffer, which stores fixed-size episodes. This currently fails for me with the following error:

RuntimeError: indexed destination TensorDict batch size is torch.Size([1001]) (batch_size = torch.Size([1000, 1001]), index=5), which differs from the source batch size torch.Size([40])

To Reproduce

The actual trace for my env is below

Traceback (most recent call last):
  File "/env/lib/python3.11/site-packages/tensordict/tensordict.py", line 3754, in __setitem__
    value = value.expand(indexed_bs)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/env/lib/python3.11/site-packages/tensordict/tensordict.py", line 4430, in expand
    raise RuntimeError(
RuntimeError: Incompatible expanded shape: The expanded shape length at non-singleton dimension should be same as the original length. target_shape = (1001,), existing_shape = torch.Size([40])

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "train.py", line 56, in train
    trainer.train()
  File "trainer/online_trainer.py", line 93, in train
    self._ep_idx = self.buffer.add(torch.cat(self._tds))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "buffer.py", line 99, in add
    self._buffer.add(tds)
  File "/env/lib/python3.11/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 255, in add
    return self._add(data)
           ^^^^^^^^^^^^^^^
  File "/env/lib/python3.11/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 259, in _add
    index = self._writer.add(data)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/env/lib/python3.11/site-packages/torchrl/data/replay_buffers/writers.py", line 55, in add
    self._storage[self._cursor] = data
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/zfs/aditya/workspace/tdmpc2/env/lib/python3.11/site-packages/torchrl/data/replay_buffers/storages.py", line 72, in __setitem__
    ret = self.set(index, value)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/env/lib/python3.11/site-packages/torchrl/data/replay_buffers/storages.py", line 323, in set
    self._storage[cursor] = data
    ~~~~~~~~~~~~~^^^^^^^^
  File "/env/lib/python3.11/site-packages/tensordict/tensordict.py", line 3756, in __setitem__
    raise RuntimeError(
RuntimeError: indexed destination TensorDict batch size is torch.Size([1001]) (batch_size = torch.Size([1000, 1001]), index=5), which differs from the source batch size torch.Size([40])

I distilled down the whole thing to the following script

import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers.storages import TensorStorage

# Storage pre-allocated for episodes of 10 steps.
data = TensorDict({
    "some data": torch.randn(1, 10, 11),
    ("some", "nested", "data"): torch.randn(1, 10, 11, 12),
}, batch_size=[1, 10, 11])
store = TensorStorage(data)

# An episode with 11 steps instead of 10: writing it into the storage raises the RuntimeError above.
data = TensorDict({
    "some data": torch.randn(1, 11, 11),
    ("some", "nested", "data"): torch.randn(1, 11, 11, 12),
}, batch_size=[1, 11, 11])
store[0] = data

Expected behavior

The tensordict should be saved? Ultimately, what I really want is a buffer that can store a variable number of steps across different episodes.

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

0.2.1 1.23.5 3.11.5 (main, Sep 2 2023, 14:16:33) [GCC 13.2.1 20230801] linux

@truncs truncs added the bug Something isn't working label Nov 3, 2023
@truncs truncs changed the title [BUG] Saving episodes of different size in the replay buffer [BUG] Saving episodes with variable steps in the replay buffer Nov 3, 2023
@vmoens
Contributor

vmoens commented Nov 3, 2023

It's something that makes sense but is hard to implement.
One way to do it would be to store everything along the same "time" dimension and index with some offset.
That would be more costly than simply sampling when the shape is predictable (like we do now), so I would expect some perf hit. It also means that you will need to create your storage accordingly (i.e., tell it ahead of time that it has to store data of different shapes along some dimension). It's a fun thing to work on, I'd love to give it a go!

There are 2 other options: nested_tensor (hard to implement when we don't know the trajectory length ahead of time) or padding (see tensordict.pad_sequence).
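
A minimal sketch of the padding option, assuming pad_sequence can be imported from the tensordict top level as referenced above (the exact signature and output layout may vary by version): two episodes of different lengths are padded to a common length so that they share a fixed shape and fit a pre-allocated storage slot.

import torch
from tensordict import TensorDict, pad_sequence

# Two episodes of different lengths (the lengths here are made up for illustration).
ep_short = TensorDict({"obs": torch.randn(40, 11)}, batch_size=[40])
ep_full = TensorDict({"obs": torch.randn(1000, 11)}, batch_size=[1000])

# Pads both episodes to the longest length and stacks them into one
# fixed-shape tensordict.
padded = pad_sequence([ep_short, ep_full])
print(padded.batch_size)  # expected: torch.Size([2, 1000])

A padding mask (for instance via the return_mask option, if present in your version) would then be needed at sampling time to ignore the padded steps.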

@truncs
Author

truncs commented Nov 3, 2023

I guess you are right. I could write a new ReplayBuffer class that stores along the step dim and maintains a dict from episode index to episode boundaries, then randomly samples episodes from this dict to get the trajectories (see the sketch below). It seems doable, but I was hoping there was something much simpler.

"tell it ahead of time that it has to store data of different shapes along some dimension"

I am not sure what you mean by that. If I use per-step as my storage dim, then I don't have the variable-dims problem? Or are you suggesting saving the accumulated steps every t seconds or so?
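
A hedged sketch of that idea (the names here, e.g. FlatEpisodeBuffer, are hypothetical and not part of torchrl): transitions are stored flat along the step dimension and a dict maps each episode to its (start, stop) offsets; capacity checks and wraparound are omitted for brevity.

import torch
from tensordict import TensorDict

class FlatEpisodeBuffer:
    def __init__(self, max_steps: int, example_step: TensorDict):
        # Pre-allocate max_steps slots shaped like a single step (batch_size []).
        self._storage = example_step.expand(max_steps).clone()
        self._cursor = 0
        self._episodes = {}  # episode_id -> (start, stop) offsets

    def add_episode(self, episode_id, episode: TensorDict):
        # episode has batch_size [n_steps]; write it flat and record its boundaries.
        n = episode.shape[0]
        self._storage[self._cursor : self._cursor + n] = episode
        self._episodes[episode_id] = (self._cursor, self._cursor + n)
        self._cursor += n

    def sample_episode(self) -> TensorDict:
        # Pick an episode uniformly at random and return its full trajectory.
        ep_id = list(self._episodes)[torch.randint(len(self._episodes), ()).item()]
        start, stop = self._episodes[ep_id]
        return self._storage[start:stop]

For example, example_step could be TensorDict({"obs": torch.zeros(11)}, batch_size=[]), after which episodes of 40 or 1000 steps can be added side by side in the same storage.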

@vmoens
Contributor

vmoens commented Nov 5, 2023

Oh no, I just meant that we'll need to adapt the add and extend methods, as well as sample, to match the logic where an item in the buffer corresponds to a range between two offsets.

@vmoens
Contributor

vmoens commented Nov 5, 2023

I played a bit with indexing with offsets and torch.compile to check if we could get any speedup that way, but it doesn't seem very promising.
In this script, I create a dummy dataset of 100K scalar values and index it with random slices of specific lengths.

The CUDA version is, on my cluster, considerably slower than the CPU counterpart.

import torch
import timeit

device = torch.device('cuda') if torch.cuda.device_count() else torch.device('cpu')

# 100K scalar values, indexed with 256 random slices of length 10 to 50.
x = torch.randn(100_000, device=device)
idx0, _ = torch.randint(100_000 - 50, (256,), device=device).sort()
idx1 = torch.randint(10, 50, (256,), device=device) + idx0
idx0 = idx0.tolist()
idx1 = idx1.tolist()

def stack_idx(idx0, idx1):
    # Build one long index tensor out of the per-slice ranges.
    out = []
    for _idx0, _idx1 in zip(idx0, idx1):
        out.append(torch.arange(_idx0, _idx1, device=device))
    return torch.cat(out)

def index_x(x):
    return x[stack_idx(idx0, idx1)]

c_index_x = torch.compile(index_x, fullgraph=True)

# Warm-up calls (the second one triggers compilation).
index_x(x)
c_index_x(x)

print(timeit.repeat("1+index_x(x)", globals={"index_x": index_x, "idx0": idx0, "idx1": idx1, "x": x}, number=100))
print(timeit.repeat("1+index_x(x)", globals={"index_x": c_index_x, "idx0": idx0, "idx1": idx1, "x": x}, number=100))

Result on cpu

[0.1146245559793897, 0.11403853702358902, 0.11392714199610054, 0.11378934100503102, 0.113719830987975]
[0.11027653800556436, 0.10902649199124426, 0.10891503200400621, 0.10922098101582378, 0.10902743798214942]

Result on cuda

[0.21882190398173407, 0.21670885902130976, 0.21903034401475452, 0.2166661199880764, 0.21840859300573356]
[0.4096674190077465, 0.4093878800049424, 0.4065087260096334, 0.40326547500444576, 0.40778385000885464]

I also tested the speed of indexing something of similar size when the sequence length is predictable:

x = x.reshape(100_000//50, 50)
idx0 = torch.randint(x.shape[0], (256,))
print(timeit.repeat("1+x[idx0]", globals={"idx0": idx0, "x": x}, number=100))

which is significantly faster

[0.005107899999984511, 0.0038608259999364236, 0.0038263060000645055, 0.003311509000013757, 0.003181937000135804]

@truncs
Author

truncs commented Nov 6, 2023

I see! At this point I might just use something I already have, which saves everything to npz (since I am using images, the env generates a lot of data) and then uses the PyTorch dataloaders to load into GPU memory. Let me know if you want me to test anything. Thanks!

@vmoens
Contributor

vmoens commented Nov 6, 2023

Can you explain why this would work better than our RBs? What feature is missing?

In theory, working with torch.tensors and not passing to numpy arrays back and forth should be faster!

@truncs
Author

truncs commented Nov 6, 2023

Speed is not too much of a concern to me right now since the worker thread prefetches the data anyway. I have more data than main memory, so I was looking for a nice and simple API that does all the loading and saving on disk, and this seems like more work than I want to commit to right now.

@vmoens vmoens mentioned this issue Dec 15, 2023
@vmoens
Contributor

vmoens commented Mar 5, 2024

Closed by #1775

@vmoens vmoens closed this as completed Mar 5, 2024