[BUG] Saving episodes with variable steps in the replay buffer #1671
Comments
It's something that makes sense but is hard to come by. There are two other options:
I guess you are right. I could write a new ReplayBuffer class that stores along the step dimension, maintains a dict mapping each episode index to its boundaries, and randomly samples episodes from this dict to get the trajectories. It seems doable, but I was hoping there was something much simpler.
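A minimal sketch of that idea, assuming a flat per-step store plus a dict of episode boundaries (the class and its methods are hypothetical, not an existing torchrl API):

```python
import torch


class FlatEpisodeBuffer:
    """Hypothetical sketch: store steps flat and keep episode boundaries in a dict."""

    def __init__(self):
        self.steps = []     # list of per-step tensors (e.g. observations)
        self.episodes = {}  # episode index -> (start, end) into the flat storage

    def add_episode(self, episode_steps):
        # episode_steps: tensor of shape [T, ...] with a variable T per episode
        start = len(self.steps)
        self.steps.extend(episode_steps.unbind(0))
        self.episodes[len(self.episodes)] = (start, start + episode_steps.shape[0])

    def sample_trajectory(self, length):
        # pick a random episode that is long enough, then a random window inside it
        candidates = [k for k, (s, e) in self.episodes.items() if e - s >= length]
        ep = candidates[torch.randint(len(candidates), (1,)).item()]
        s, e = self.episodes[ep]
        t0 = s + torch.randint(e - s - length + 1, (1,)).item()
        return torch.stack(self.steps[t0:t0 + length])
```

Usage would be along the lines of `buf.add_episode(torch.randn(T, obs_dim))` followed by `buf.sample_trajectory(10)`.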
Oh no, I just meant that we'll need to adapt the
I played a bit with indexing with offsets and torch.compile to check if we could get any speedup with that, but it doesn't seem very promising. On my cluster, the CUDA version is considerably slower than the CPU counterpart.

```python
import torch
import timeit

device = torch.device('cuda') if torch.cuda.device_count() else torch.device('cpu')

x = torch.randn(100_000, device=device)
# 256 random (start, stop) pairs with lengths between 10 and 50
idx0, _ = torch.randint(100_000 - 50, (256,), device=device).sort()
idx1 = torch.randint(10, 50, (256,), device=device) + idx0
idx0 = idx0.tolist()
idx1 = idx1.tolist()

def stack_idx(idx0, idx1):
    # build one flat index tensor by concatenating an arange per (start, stop) pair
    out = []
    for _idx0, _idx1 in zip(idx0, idx1):
        out.append(torch.arange(_idx0, _idx1, device=device))
    return torch.cat(out)

def index_x(x):
    return x[stack_idx(idx0, idx1)]

c_index_x = torch.compile(index_x, fullgraph=True)

# warm-up runs (the first compiled call triggers compilation)
index_x(x)
c_index_x(x)

print(timeit.repeat("1+index_x(x)", globals={"index_x": index_x, "idx0": idx0, "idx1": idx1, "x": x}, number=100))
# same statement, but with the compiled function bound to the name index_x
print(timeit.repeat("1+index_x(x)", globals={"index_x": c_index_x, "idx0": idx0, "idx1": idx1, "x": x}, number=100))
```

Result on cpu
Result on cuda
I also tested the speed of indexing something of similar size when the sequence length is predictable:

```python
x = x.reshape(100_000 // 50, 50)
idx0 = torch.randint(x.shape[0], (256,))
print(timeit.repeat("1+x[idx0]", globals={"idx0": idx0, "x": x}, number=100))
```

which is significantly faster.
I see! At this point I might just use something that I already have, which saves everything to npz (since I am using images, the env generates a lot of data) and then uses the PyTorch dataloaders to load into GPU memory. Let me know if you want me to test anything. Thanks!
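A minimal sketch of that kind of npz-backed pipeline, assuming one .npz file per episode and a fixed window length (the file layout, key name, and class are assumptions for illustration, not from the original comment):

```python
import glob

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class NpzEpisodeDataset(Dataset):
    """Hypothetical sketch: each .npz file holds one episode of images under the key 'obs'."""

    def __init__(self, pattern, window=10):
        self.files = sorted(glob.glob(pattern))
        self.window = window

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        data = np.load(self.files[i])
        obs = data["obs"]  # shape [T, H, W, C], variable T per episode
        # assumes every episode has at least `window` steps
        t0 = np.random.randint(0, obs.shape[0] - self.window + 1)
        return torch.from_numpy(obs[t0:t0 + self.window])


# DataLoader workers prefetch windows from disk; batches are moved to GPU in the training loop
loader = DataLoader(NpzEpisodeDataset("episodes/*.npz"), batch_size=32, num_workers=4)
```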
Can you explain why this would work better than our RBs? What feature is missing? In theory, working with torch.tensors rather than converting back and forth to numpy arrays should be faster!
Speed is not of much concern to me right now, since the worker thread prefetches the data anyway. I have more data than fits in main memory, so I was looking for a nice and simple API that does all the loading and saving on disk, and this seems like more work than I want to commit to right now.
Closed by #1775
In an env where the episode terminates early, I would still like to save fixed-size episodes from it in the buffer. This currently fails for me, giving the error
To Reproduce
The actual trace for my env is below
I distilled the whole thing down to the following script
Expected behavior
The tensordict should be saved. Ultimately, what I really want is a buffer that can store a variable number of steps in different episodes.
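One common workaround (not from this issue, just an illustrative sketch) is to pad a short episode up to the buffer's fixed episode length and carry a validity mask, so downstream losses can ignore the padded steps; the helper below is hypothetical:

```python
import torch
import torch.nn.functional as F


def pad_episode(obs, fixed_len):
    """Hypothetical helper: pad a [T, ...] episode to [fixed_len, ...] and return a step mask."""
    t = obs.shape[0]
    mask = torch.zeros(fixed_len, dtype=torch.bool)
    mask[:t] = True
    # pad only the leading (time) dimension with zeros
    pad = [0, 0] * (obs.dim() - 1) + [0, fixed_len - t]
    return F.pad(obs, pad), mask


obs = torch.randn(37, 3, 64, 64)  # an episode that terminated after 37 steps
padded, mask = pad_episode(obs, fixed_len=50)
assert padded.shape[0] == 50 and mask.sum() == 37
```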
System info
Describe the characteristics of your environment:
- torchrl: 0.2.1
- numpy: 1.23.5
- Python: 3.11.5 (main, Sep 2 2023, 14:16:33) [GCC 13.2.1 20230801]
- OS: linux