Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] batched trajectories - SliceSampler compatibility #1775

Merged
merged 34 commits into from
Feb 22, 2024
Merged

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Jan 8, 2024

What does this PR do?

We introduce a new way of extending replay buffers that is more faithful to the way trajectories are organised in torchrl.
Improvements:

  • Replay buffers can handle an arbitrary number of dimensions and users can choose which dimension to use for extension.
  • Replay buffer size can reflect the total number of transitions even for multi-dimensional storages.
  • The samplers (incl. SliceSampler and subclasses) can handle different batch dimensions. For multidim storages, samples are always flattened (as if the storage was indexed by a mask)

Why is this change needed

We currently have 2 ways of extending a buffer which I summarize in this ugly drawings

image
image

The first attempt we had at storing trajectories was to build the buffer according to

T = ... # length of a trajectory
B = ... # number of envs simulated at once in the collector or in ParallelEnv
rb = ReplayBuffer(storage=LazyTensorStorage(max_size=total_num_transitions // T), transfrom=RandomCropTensorDict(...))
for data in collector:
   # data has shape [B, T]
   rb.extend(data)

Sampling consists in getting trajectories returned by the collector. This isn't amazing because trajectories could be interrupted in the middle with a truncated/terminated signal, or simply be longer than T.
The other issue is that your buffer size isn't the number of transitions but the number of traj stored. One can immediately see why this isn't as flexible as one could hope for!
Finaly drawback: subsampling slices needed to be done via

class RandomCropTensorDict(Transform):
which is slow.

The second attempt was simply to flatten the tensordict and use SliceSampler to get sub-trajectories

T = ... # length of a trajectory
B = ... # number of envs simulated at once in the collector or in ParallelEnv
rb = ReplayBuffer(storage=LazyTensorStorage(max_size=total_num_transitions), sampler=SliceSampler(...))
for data in collector:
   # data has shape [B, T]
   rb.extend(data.reshape(-1))

Here the drawings above come in handy: we still have the problem that trajectories are not contiguous and can be truncated in multiple places if a ParallelEnv is used.
But it's closer to what people want and the buffer size is accurate.

New API

The new API is quite simple: one just needs to pass the number of dimensions to the buffer and all the other components will know what to do:

T = ... # length of a trajectory
B = ... # number of envs simulated at once in the collector or in ParallelEnv
rb = ReplayBuffer(storage=LazyTensorStorage(max_size=total_num_transitions, ndim=2), sampler=SliceSampler(...))
for data in collector:
   # data has shape [B, T]
   rb.extend(data) # rb is `data.numel()` longer!

This is the second picture I pasted earlier. The big advantage with SyncCollector + ParallelEnv is that you now have trajectories that are well structured and as long as you could wish.

## Other changes

  • Using tensorclasses within TensorDictReplayBuffer is a bit buggy because we can't store the indices so I made this not possible anymore. I believe it is the only bc-breaking change of this PR, but all the features that used to be supported still are via a ReplayBuffer usage.
  • I removed the "_data" prefix in storages placed within TensorDictReplayBuffer instances.

Gist: https://gist.github.com/vmoens/b928af1a3a9567b1f0862a58ec592f85

cc @btx0424 @albertbou92 @Cadene

Copy link

pytorch-bot bot commented Jan 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1775

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Merge Blocking SEVs

There is 1 active merge blocking SEVs. Please view them below:

If you must merge, use @pytorchbot merge -f.

⏳ No Failures, 42 Pending

As of commit 36ea29f with merge base bb44067 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 8, 2024
Copy link

github-actions bot commented Jan 8, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}19$. Worsened: $\large\color{#d91a1a}4$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 61.7605ms 61.4829ms 16.2647 Ops/s 15.5168 Ops/s $\color{#35bf28}+4.82\%$
test_sync 38.9363ms 33.4051ms 29.9356 Ops/s 28.4242 Ops/s $\textbf{\color{#35bf28}+5.32\%}$
test_async 55.3167ms 31.2616ms 31.9881 Ops/s 30.6990 Ops/s $\color{#35bf28}+4.20\%$
test_simple 0.5124s 0.4391s 2.2772 Ops/s 2.3381 Ops/s $\color{#d91a1a}-2.61\%$
test_transformed 0.6322s 0.5858s 1.7071 Ops/s 1.7136 Ops/s $\color{#d91a1a}-0.38\%$
test_serial 1.4805s 1.4325s 0.6981 Ops/s 0.6939 Ops/s $\color{#35bf28}+0.60\%$
test_parallel 1.4721s 1.4204s 0.7040 Ops/s 0.7030 Ops/s $\color{#35bf28}+0.15\%$
test_step_mdp_speed[True-True-True-True-True] 0.1567ms 22.0502μs 45.3510 KOps/s 46.6218 KOps/s $\color{#d91a1a}-2.73\%$
test_step_mdp_speed[True-True-True-True-False] 46.7070μs 13.3538μs 74.8849 KOps/s 74.8470 KOps/s $\color{#35bf28}+0.05\%$
test_step_mdp_speed[True-True-True-False-True] 43.0710μs 12.8596μs 77.7626 KOps/s 79.8093 KOps/s $\color{#d91a1a}-2.56\%$
test_step_mdp_speed[True-True-True-False-False] 46.6170μs 7.6798μs 130.2115 KOps/s 131.2072 KOps/s $\color{#d91a1a}-0.76\%$
test_step_mdp_speed[True-True-False-True-True] 58.1290μs 23.6290μs 42.3209 KOps/s 43.7154 KOps/s $\color{#d91a1a}-3.19\%$
test_step_mdp_speed[True-True-False-True-False] 56.0450μs 14.6297μs 68.3541 KOps/s 68.1826 KOps/s $\color{#35bf28}+0.25\%$
test_step_mdp_speed[True-True-False-False-True] 41.9590μs 13.9872μs 71.4942 KOps/s 72.5981 KOps/s $\color{#d91a1a}-1.52\%$
test_step_mdp_speed[True-True-False-False-False] 52.0170μs 8.9891μs 111.2456 KOps/s 112.0181 KOps/s $\color{#d91a1a}-0.69\%$
test_step_mdp_speed[True-False-True-True-True] 54.6120μs 24.9586μs 40.0664 KOps/s 40.7573 KOps/s $\color{#d91a1a}-1.70\%$
test_step_mdp_speed[True-False-True-True-False] 55.3240μs 16.0185μs 62.4277 KOps/s 62.8074 KOps/s $\color{#d91a1a}-0.60\%$
test_step_mdp_speed[True-False-True-False-True] 65.9940μs 14.0679μs 71.0838 KOps/s 72.8166 KOps/s $\color{#d91a1a}-2.38\%$
test_step_mdp_speed[True-False-True-False-False] 26.3090μs 9.0204μs 110.8594 KOps/s 111.7196 KOps/s $\color{#d91a1a}-0.77\%$
test_step_mdp_speed[True-False-False-True-True] 80.9910μs 25.9993μs 38.4626 KOps/s 38.9732 KOps/s $\color{#d91a1a}-1.31\%$
test_step_mdp_speed[True-False-False-True-False] 45.6960μs 17.2574μs 57.9462 KOps/s 58.6715 KOps/s $\color{#d91a1a}-1.24\%$
test_step_mdp_speed[True-False-False-False-True] 49.7930μs 15.0836μs 66.2972 KOps/s 66.9598 KOps/s $\color{#d91a1a}-0.99\%$
test_step_mdp_speed[True-False-False-False-False] 32.6910μs 10.2589μs 97.4759 KOps/s 98.6249 KOps/s $\color{#d91a1a}-1.17\%$
test_step_mdp_speed[False-True-True-True-True] 57.3270μs 24.8973μs 40.1650 KOps/s 40.7136 KOps/s $\color{#d91a1a}-1.35\%$
test_step_mdp_speed[False-True-True-True-False] 55.3840μs 16.1713μs 61.8378 KOps/s 62.3015 KOps/s $\color{#d91a1a}-0.74\%$
test_step_mdp_speed[False-True-True-False-True] 38.9530μs 16.4260μs 60.8791 KOps/s 62.2316 KOps/s $\color{#d91a1a}-2.17\%$
test_step_mdp_speed[False-True-True-False-False] 43.3820μs 10.3518μs 96.6017 KOps/s 98.5983 KOps/s $\color{#d91a1a}-2.02\%$
test_step_mdp_speed[False-True-False-True-True] 42.6300μs 26.7426μs 37.3935 KOps/s 38.3560 KOps/s $\color{#d91a1a}-2.51\%$
test_step_mdp_speed[False-True-False-True-False] 47.6090μs 17.4203μs 57.4043 KOps/s 58.1870 KOps/s $\color{#d91a1a}-1.35\%$
test_step_mdp_speed[False-True-False-False-True] 62.2870μs 17.5520μs 56.9736 KOps/s 58.1164 KOps/s $\color{#d91a1a}-1.97\%$
test_step_mdp_speed[False-True-False-False-False] 31.1980μs 11.6036μs 86.1805 KOps/s 88.2029 KOps/s $\color{#d91a1a}-2.29\%$
test_step_mdp_speed[False-False-True-True-True] 70.3720μs 27.3795μs 36.5236 KOps/s 37.0359 KOps/s $\color{#d91a1a}-1.38\%$
test_step_mdp_speed[False-False-True-True-False] 41.5080μs 18.6850μs 53.5189 KOps/s 53.9434 KOps/s $\color{#d91a1a}-0.79\%$
test_step_mdp_speed[False-False-True-False-True] 63.2690μs 17.5990μs 56.8215 KOps/s 57.9755 KOps/s $\color{#d91a1a}-1.99\%$
test_step_mdp_speed[False-False-True-False-False] 34.6550μs 11.6243μs 86.0267 KOps/s 88.2819 KOps/s $\color{#d91a1a}-2.55\%$
test_step_mdp_speed[False-False-False-True-True] 76.1730μs 28.6207μs 34.9397 KOps/s 35.8407 KOps/s $\color{#d91a1a}-2.51\%$
test_step_mdp_speed[False-False-False-True-False] 68.4680μs 19.7379μs 50.6638 KOps/s 49.1215 KOps/s $\color{#35bf28}+3.14\%$
test_step_mdp_speed[False-False-False-False-True] 43.6220μs 18.5895μs 53.7937 KOps/s 54.3925 KOps/s $\color{#d91a1a}-1.10\%$
test_step_mdp_speed[False-False-False-False-False] 58.9200μs 12.5987μs 79.3734 KOps/s 79.8398 KOps/s $\color{#d91a1a}-0.58\%$
test_values[generalized_advantage_estimate-True-True] 11.2049ms 9.3930ms 106.4622 Ops/s 111.2378 Ops/s $\color{#d91a1a}-4.29\%$
test_values[vec_generalized_advantage_estimate-True-True] 36.6587ms 33.2840ms 30.0444 Ops/s 28.4262 Ops/s $\textbf{\color{#35bf28}+5.69\%}$
test_values[td0_return_estimate-False-False] 0.2310ms 0.1897ms 5.2711 KOps/s 6.0523 KOps/s $\textbf{\color{#d91a1a}-12.91\%}$
test_values[td1_return_estimate-False-False] 23.7823ms 23.4988ms 42.5554 Ops/s 44.7817 Ops/s $\color{#d91a1a}-4.97\%$
test_values[vec_td1_return_estimate-False-False] 47.8651ms 33.7301ms 29.6471 Ops/s 28.2776 Ops/s $\color{#35bf28}+4.84\%$
test_values[td_lambda_return_estimate-True-False] 53.8597ms 33.7039ms 29.6702 Ops/s 31.0197 Ops/s $\color{#d91a1a}-4.35\%$
test_values[vec_td_lambda_return_estimate-True-False] 34.8422ms 33.2351ms 30.0887 Ops/s 28.3415 Ops/s $\textbf{\color{#35bf28}+6.16\%}$
test_gae_speed[generalized_advantage_estimate-False-1-512] 9.0191ms 8.1233ms 123.1026 Ops/s 125.2650 Ops/s $\color{#d91a1a}-1.73\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 2.2621ms 1.9807ms 504.8803 Ops/s 504.6952 Ops/s $\color{#35bf28}+0.04\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.4238ms 0.3459ms 2.8911 KOps/s 2.9210 KOps/s $\color{#d91a1a}-1.02\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 42.9739ms 40.4440ms 24.7255 Ops/s 22.3914 Ops/s $\textbf{\color{#35bf28}+10.42\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 4.1889ms 3.0309ms 329.9346 Ops/s 331.6252 Ops/s $\color{#d91a1a}-0.51\%$
test_dqn_speed 6.9495ms 1.3933ms 717.7203 Ops/s 717.7059 Ops/s $+0.00\%$
test_ddpg_speed 4.7414ms 2.7765ms 360.1635 Ops/s 361.5827 Ops/s $\color{#d91a1a}-0.39\%$
test_sac_speed 73.8254ms 8.8557ms 112.9212 Ops/s 120.7015 Ops/s $\textbf{\color{#d91a1a}-6.45\%}$
test_redq_speed 14.2549ms 13.1249ms 76.1911 Ops/s 73.5567 Ops/s $\color{#35bf28}+3.58\%$
test_redq_deprec_speed 13.8899ms 13.2558ms 75.4388 Ops/s 76.6079 Ops/s $\color{#d91a1a}-1.53\%$
test_td3_speed 8.6680ms 8.3621ms 119.5865 Ops/s 122.3667 Ops/s $\color{#d91a1a}-2.27\%$
test_cql_speed 37.7721ms 36.2090ms 27.6175 Ops/s 26.3220 Ops/s $\color{#35bf28}+4.92\%$
test_a2c_speed 8.6546ms 7.4201ms 134.7685 Ops/s 132.4834 Ops/s $\color{#35bf28}+1.72\%$
test_ppo_speed 8.8849ms 7.6707ms 130.3670 Ops/s 127.3111 Ops/s $\color{#35bf28}+2.40\%$
test_reinforce_speed 7.2483ms 6.6013ms 151.4852 Ops/s 144.2457 Ops/s $\textbf{\color{#35bf28}+5.02\%}$
test_iql_speed 33.8294ms 32.5761ms 30.6974 Ops/s 30.8052 Ops/s $\color{#d91a1a}-0.35\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.5044ms 2.2437ms 445.6973 Ops/s 357.2003 Ops/s $\textbf{\color{#35bf28}+24.78\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.9650ms 0.4969ms 2.0127 KOps/s 1.9663 KOps/s $\color{#35bf28}+2.36\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.9056ms 0.5024ms 1.9906 KOps/s 1.9987 KOps/s $\color{#d91a1a}-0.40\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.0501ms 2.2505ms 444.3364 Ops/s 358.3803 Ops/s $\textbf{\color{#35bf28}+23.98\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.0570ms 0.4896ms 2.0427 KOps/s 1.9901 KOps/s $\color{#35bf28}+2.64\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.7268ms 0.4675ms 2.1389 KOps/s 2.1060 KOps/s $\color{#35bf28}+1.56\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.4259ms 2.4072ms 415.4250 Ops/s 346.6520 Ops/s $\textbf{\color{#35bf28}+19.84\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.1513ms 0.6113ms 1.6358 KOps/s 1.6057 KOps/s $\color{#35bf28}+1.87\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 1.0210ms 0.5887ms 1.6986 KOps/s 1.6806 KOps/s $\color{#35bf28}+1.07\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.4080ms 2.2041ms 453.7061 Ops/s 361.8144 Ops/s $\textbf{\color{#35bf28}+25.40\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.5914ms 0.4932ms 2.0275 KOps/s 1.9594 KOps/s $\color{#35bf28}+3.47\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 90.9529ms 0.5666ms 1.7650 KOps/s 2.0816 KOps/s $\textbf{\color{#d91a1a}-15.21\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.2751ms 2.2741ms 439.7311 Ops/s 356.6487 Ops/s $\textbf{\color{#35bf28}+23.30\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.0627ms 0.4897ms 2.0419 KOps/s 1.9883 KOps/s $\color{#35bf28}+2.70\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.8678ms 0.4710ms 2.1232 KOps/s 2.0955 KOps/s $\color{#35bf28}+1.32\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.4008ms 2.3610ms 423.5445 Ops/s 342.6617 Ops/s $\textbf{\color{#35bf28}+23.60\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 88.4038ms 0.7100ms 1.4085 KOps/s 1.5917 KOps/s $\textbf{\color{#d91a1a}-11.51\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8451ms 0.5899ms 1.6953 KOps/s 1.6538 KOps/s $\color{#35bf28}+2.51\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 87.5562ms 5.2885ms 189.0889 Ops/s 134.4109 Ops/s $\textbf{\color{#35bf28}+40.68\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 14.0285ms 11.8655ms 84.2782 Ops/s 68.0966 Ops/s $\textbf{\color{#35bf28}+23.76\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1.6598ms 1.0415ms 960.1530 Ops/s 394.2132 Ops/s $\textbf{\color{#35bf28}+143.56\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 86.0137ms 6.8708ms 145.5441 Ops/s 140.4029 Ops/s $\color{#35bf28}+3.66\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 14.2004ms 11.8091ms 84.6807 Ops/s 76.0607 Ops/s $\textbf{\color{#35bf28}+11.33\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 4.2582ms 1.1295ms 885.3269 Ops/s 392.8976 Ops/s $\textbf{\color{#35bf28}+125.33\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 88.2532ms 7.3028ms 136.9344 Ops/s 109.5690 Ops/s $\textbf{\color{#35bf28}+24.98\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 14.8371ms 12.1849ms 82.0690 Ops/s 74.9371 Ops/s $\textbf{\color{#35bf28}+9.52\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 2.0048ms 1.3395ms 746.5291 Ops/s 360.2437 Ops/s $\textbf{\color{#35bf28}+107.23\%}$

Copy link

github-actions bot commented Jan 8, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of GPU Benchmark Tests

Total Benchmarks: 92. Improved: $\large\color{#35bf28}21$. Worsened: $\large\color{#d91a1a}1$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1138s 0.1133s 8.8274 Ops/s 8.3621 Ops/s $\textbf{\color{#35bf28}+5.56\%}$
test_sync 95.8094ms 95.6176ms 10.4583 Ops/s 10.4368 Ops/s $\color{#35bf28}+0.21\%$
test_async 0.1818s 91.9722ms 10.8729 Ops/s 10.9175 Ops/s $\color{#d91a1a}-0.41\%$
test_single_pixels 0.1270s 0.1266s 7.8977 Ops/s 7.3044 Ops/s $\textbf{\color{#35bf28}+8.12\%}$
test_sync_pixels 83.5962ms 80.9378ms 12.3552 Ops/s 12.3215 Ops/s $\color{#35bf28}+0.27\%$
test_async_pixels 0.1506s 67.8172ms 14.7455 Ops/s 14.7193 Ops/s $\color{#35bf28}+0.18\%$
test_simple 0.8807s 0.8239s 1.2138 Ops/s 1.2196 Ops/s $\color{#d91a1a}-0.48\%$
test_transformed 1.1043s 1.0472s 0.9549 Ops/s 0.9614 Ops/s $\color{#d91a1a}-0.68\%$
test_serial 2.4952s 2.4411s 0.4097 Ops/s 0.4171 Ops/s $\color{#d91a1a}-1.78\%$
test_parallel 2.1850s 2.1109s 0.4737 Ops/s 0.4856 Ops/s $\color{#d91a1a}-2.45\%$
test_step_mdp_speed[True-True-True-True-True] 0.1227ms 32.7150μs 30.5670 KOps/s 30.3818 KOps/s $\color{#35bf28}+0.61\%$
test_step_mdp_speed[True-True-True-True-False] 44.6810μs 19.8417μs 50.3988 KOps/s 51.2168 KOps/s $\color{#d91a1a}-1.60\%$
test_step_mdp_speed[True-True-True-False-True] 41.0000μs 18.4188μs 54.2925 KOps/s 54.9081 KOps/s $\color{#d91a1a}-1.12\%$
test_step_mdp_speed[True-True-True-False-False] 32.5210μs 11.1704μs 89.5225 KOps/s 92.0545 KOps/s $\color{#d91a1a}-2.75\%$
test_step_mdp_speed[True-True-False-True-True] 72.6810μs 34.2269μs 29.2168 KOps/s 29.4303 KOps/s $\color{#d91a1a}-0.73\%$
test_step_mdp_speed[True-True-False-True-False] 40.5210μs 21.5466μs 46.4111 KOps/s 47.7799 KOps/s $\color{#d91a1a}-2.86\%$
test_step_mdp_speed[True-True-False-False-True] 42.5410μs 20.0660μs 49.8356 KOps/s 50.1369 KOps/s $\color{#d91a1a}-0.60\%$
test_step_mdp_speed[True-True-False-False-False] 35.1300μs 13.0503μs 76.6267 KOps/s 79.0527 KOps/s $\color{#d91a1a}-3.07\%$
test_step_mdp_speed[True-False-True-True-True] 69.0210μs 36.2206μs 27.6086 KOps/s 27.4127 KOps/s $\color{#35bf28}+0.71\%$
test_step_mdp_speed[True-False-True-True-False] 47.1110μs 23.2517μs 43.0076 KOps/s 43.1437 KOps/s $\color{#d91a1a}-0.32\%$
test_step_mdp_speed[True-False-True-False-True] 43.3910μs 19.9767μs 50.0584 KOps/s 49.5004 KOps/s $\color{#35bf28}+1.13\%$
test_step_mdp_speed[True-False-True-False-False] 29.5400μs 12.9599μs 77.1611 KOps/s 78.4424 KOps/s $\color{#d91a1a}-1.63\%$
test_step_mdp_speed[True-False-False-True-True] 62.2210μs 38.2642μs 26.1341 KOps/s 26.0432 KOps/s $\color{#35bf28}+0.35\%$
test_step_mdp_speed[True-False-False-True-False] 84.0510μs 25.1269μs 39.7979 KOps/s 40.0967 KOps/s $\color{#d91a1a}-0.75\%$
test_step_mdp_speed[True-False-False-False-True] 45.2510μs 21.7080μs 46.0659 KOps/s 45.5785 KOps/s $\color{#35bf28}+1.07\%$
test_step_mdp_speed[True-False-False-False-False] 40.4100μs 14.7858μs 67.6326 KOps/s 68.1797 KOps/s $\color{#d91a1a}-0.80\%$
test_step_mdp_speed[False-True-True-True-True] 57.4620μs 37.0373μs 26.9998 KOps/s 27.6064 KOps/s $\color{#d91a1a}-2.20\%$
test_step_mdp_speed[False-True-True-True-False] 48.2710μs 23.4426μs 42.6573 KOps/s 43.3717 KOps/s $\color{#d91a1a}-1.65\%$
test_step_mdp_speed[False-True-True-False-True] 39.7410μs 23.9434μs 41.7652 KOps/s 41.8496 KOps/s $\color{#d91a1a}-0.20\%$
test_step_mdp_speed[False-True-True-False-False] 31.7210μs 14.8048μs 67.5455 KOps/s 67.6930 KOps/s $\color{#d91a1a}-0.22\%$
test_step_mdp_speed[False-True-False-True-True] 64.9710μs 38.3905μs 26.0481 KOps/s 25.8688 KOps/s $\color{#35bf28}+0.69\%$
test_step_mdp_speed[False-True-False-True-False] 55.4010μs 25.4911μs 39.2294 KOps/s 39.8928 KOps/s $\color{#d91a1a}-1.66\%$
test_step_mdp_speed[False-True-False-False-True] 43.4600μs 25.3919μs 39.3826 KOps/s 38.6222 KOps/s $\color{#35bf28}+1.97\%$
test_step_mdp_speed[False-True-False-False-False] 37.4600μs 16.5701μs 60.3496 KOps/s 59.9255 KOps/s $\color{#35bf28}+0.71\%$
test_step_mdp_speed[False-False-True-True-True] 66.8810μs 40.0311μs 24.9806 KOps/s 24.5661 KOps/s $\color{#35bf28}+1.69\%$
test_step_mdp_speed[False-False-True-True-False] 50.4310μs 26.8823μs 37.1992 KOps/s 37.3448 KOps/s $\color{#d91a1a}-0.39\%$
test_step_mdp_speed[False-False-True-False-True] 48.6700μs 25.6584μs 38.9735 KOps/s 39.1097 KOps/s $\color{#d91a1a}-0.35\%$
test_step_mdp_speed[False-False-True-False-False] 34.9600μs 16.7036μs 59.8673 KOps/s 60.8397 KOps/s $\color{#d91a1a}-1.60\%$
test_step_mdp_speed[False-False-False-True-True] 60.0720μs 41.1239μs 24.3168 KOps/s 24.3450 KOps/s $\color{#d91a1a}-0.12\%$
test_step_mdp_speed[False-False-False-True-False] 47.7410μs 28.9886μs 34.4963 KOps/s 34.8719 KOps/s $\color{#d91a1a}-1.08\%$
test_step_mdp_speed[False-False-False-False-True] 48.3310μs 27.3976μs 36.4995 KOps/s 37.2188 KOps/s $\color{#d91a1a}-1.93\%$
test_step_mdp_speed[False-False-False-False-False] 46.5610μs 18.3354μs 54.5393 KOps/s 55.2981 KOps/s $\color{#d91a1a}-1.37\%$
test_values[generalized_advantage_estimate-True-True] 25.4036ms 24.5794ms 40.6845 Ops/s 40.0527 Ops/s $\color{#35bf28}+1.58\%$
test_values[vec_generalized_advantage_estimate-True-True] 82.3914ms 3.2026ms 312.2462 Ops/s 307.9053 Ops/s $\color{#35bf28}+1.41\%$
test_values[td0_return_estimate-False-False] 99.1520μs 59.1047μs 16.9191 KOps/s 16.7833 KOps/s $\color{#35bf28}+0.81\%$
test_values[td1_return_estimate-False-False] 53.3329ms 51.6958ms 19.3439 Ops/s 18.1632 Ops/s $\textbf{\color{#35bf28}+6.50\%}$
test_values[vec_td1_return_estimate-False-False] 2.0812ms 1.7529ms 570.4759 Ops/s 566.9102 Ops/s $\color{#35bf28}+0.63\%$
test_values[td_lambda_return_estimate-True-False] 85.3080ms 82.8142ms 12.0752 Ops/s 11.4524 Ops/s $\textbf{\color{#35bf28}+5.44\%}$
test_values[vec_td_lambda_return_estimate-True-False] 3.9945ms 1.7913ms 558.2577 Ops/s 557.9985 Ops/s $\color{#35bf28}+0.05\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 23.9872ms 23.3662ms 42.7969 Ops/s 41.6757 Ops/s $\color{#35bf28}+2.69\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 0.8772ms 0.6879ms 1.4538 KOps/s 1.4517 KOps/s $\color{#35bf28}+0.14\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.7095ms 0.6495ms 1.5396 KOps/s 1.4905 KOps/s $\color{#35bf28}+3.30\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.4905ms 1.4445ms 692.2683 Ops/s 690.4366 Ops/s $\color{#35bf28}+0.27\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 0.9266ms 0.6623ms 1.5100 KOps/s 1.5014 KOps/s $\color{#35bf28}+0.57\%$
test_dqn_speed 9.1106ms 1.4882ms 671.9731 Ops/s 703.0199 Ops/s $\color{#d91a1a}-4.42\%$
test_ddpg_speed 3.2327ms 2.8339ms 352.8712 Ops/s 361.2393 Ops/s $\color{#d91a1a}-2.32\%$
test_sac_speed 8.8175ms 8.1519ms 122.6713 Ops/s 124.9361 Ops/s $\color{#d91a1a}-1.81\%$
test_redq_speed 86.7981ms 10.9651ms 91.1983 Ops/s 100.5016 Ops/s $\textbf{\color{#d91a1a}-9.26\%}$
test_redq_deprec_speed 11.7090ms 11.1497ms 89.6883 Ops/s 92.1160 Ops/s $\color{#d91a1a}-2.64\%$
test_td3_speed 8.4923ms 8.2300ms 121.5071 Ops/s 123.5740 Ops/s $\color{#d91a1a}-1.67\%$
test_cql_speed 26.8404ms 25.1334ms 39.7877 Ops/s 40.4258 Ops/s $\color{#d91a1a}-1.58\%$
test_a2c_speed 5.8876ms 5.4639ms 183.0204 Ops/s 183.8140 Ops/s $\color{#d91a1a}-0.43\%$
test_ppo_speed 6.0088ms 5.7648ms 173.4668 Ops/s 173.8513 Ops/s $\color{#d91a1a}-0.22\%$
test_reinforce_speed 5.3957ms 4.5003ms 222.2057 Ops/s 223.7258 Ops/s $\color{#d91a1a}-0.68\%$
test_iql_speed 20.0048ms 19.0825ms 52.4041 Ops/s 52.6664 Ops/s $\color{#d91a1a}-0.50\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.9580ms 2.8524ms 350.5764 Ops/s 277.6233 Ops/s $\textbf{\color{#35bf28}+26.28\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.8090ms 0.5295ms 1.8885 KOps/s 1.8198 KOps/s $\color{#35bf28}+3.77\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.7265ms 0.5071ms 1.9721 KOps/s 1.9217 KOps/s $\color{#35bf28}+2.62\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.0744ms 2.8903ms 345.9855 Ops/s 274.0566 Ops/s $\textbf{\color{#35bf28}+26.25\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6267ms 0.5261ms 1.9008 KOps/s 1.8436 KOps/s $\color{#35bf28}+3.11\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.4417ms 0.5093ms 1.9637 KOps/s 1.9451 KOps/s $\color{#35bf28}+0.95\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.0581ms 2.9776ms 335.8424 Ops/s 265.9182 Ops/s $\textbf{\color{#35bf28}+26.30\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.9132ms 0.6617ms 1.5112 KOps/s 1.3263 KOps/s $\textbf{\color{#35bf28}+13.94\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8186ms 0.6345ms 1.5760 KOps/s 1.5659 KOps/s $\color{#35bf28}+0.64\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.0149ms 2.8582ms 349.8752 Ops/s 272.6674 Ops/s $\textbf{\color{#35bf28}+28.32\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.7935ms 0.5343ms 1.8715 KOps/s 1.8310 KOps/s $\color{#35bf28}+2.21\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.3150ms 0.5136ms 1.9471 KOps/s 1.9145 KOps/s $\color{#35bf28}+1.70\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.0352ms 2.8793ms 347.3073 Ops/s 272.9064 Ops/s $\textbf{\color{#35bf28}+27.26\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6454ms 0.5273ms 1.8965 KOps/s 1.8451 KOps/s $\color{#35bf28}+2.79\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 8.8636ms 0.5150ms 1.9417 KOps/s 1.6296 KOps/s $\textbf{\color{#35bf28}+19.15\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.1023ms 2.9845ms 335.0682 Ops/s 264.8870 Ops/s $\textbf{\color{#35bf28}+26.49\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.4746ms 0.6609ms 1.5130 KOps/s 1.4986 KOps/s $\color{#35bf28}+0.96\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8272ms 0.6363ms 1.5716 KOps/s 1.5560 KOps/s $\color{#35bf28}+1.01\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1076s 8.5689ms 116.7006 Ops/s 88.9979 Ops/s $\textbf{\color{#35bf28}+31.13\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 19.5644ms 14.1955ms 70.4448 Ops/s 63.3763 Ops/s $\textbf{\color{#35bf28}+11.15\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1.1576ms 1.0507ms 951.7723 Ops/s 334.2845 Ops/s $\textbf{\color{#35bf28}+184.72\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1006s 8.5018ms 117.6225 Ops/s 90.2459 Ops/s $\textbf{\color{#35bf28}+30.34\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 16.6164ms 14.0796ms 71.0249 Ops/s 63.6031 Ops/s $\textbf{\color{#35bf28}+11.67\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 2.4815ms 1.1792ms 848.0278 Ops/s 335.3423 Ops/s $\textbf{\color{#35bf28}+152.88\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 99.2904ms 6.9059ms 144.8039 Ops/s 105.3884 Ops/s $\textbf{\color{#35bf28}+37.40\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 16.8748ms 14.2394ms 70.2277 Ops/s 55.8256 Ops/s $\textbf{\color{#35bf28}+25.80\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 7.5653ms 1.6102ms 621.0228 Ops/s 304.5492 Ops/s $\textbf{\color{#35bf28}+103.92\%}$

# Conflicts:
#	torchrl/data/replay_buffers/replay_buffers.py
#	torchrl/envs/transforms/transforms.py
@vmoens vmoens added the enhancement New feature or request label Feb 7, 2024
@vmoens
Copy link
Contributor Author

vmoens commented Feb 7, 2024

I just pushed a completely broken commit but here's the status:
For slice samplers we now need to be able to get indices of the type

tensor([3, 1],
[4, 1],
[5, 1])

where the first column is the index along the time dimension and the second the index along the batch dimension (since dims are flipped in the buffer).
That means that indexing the storage with storage[idx] won't work since we need to use the first and second columns separately. We will need to do storage[idx.unbind(-1)] which will work just fine.

storage = rb._storage._storage
else:
storage = rb._storage._storage.get("_data")
storage = rb._storage._storage[:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to use copy() or clone() to be more explicit than [:]

Suggested change
storage = rb._storage._storage[:]
storage = rb._storage._storage.copy()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The [:] does not copy but it indexes all valid elements of the storage (ie the storage up to its maximum length).
Note also that this can return anything (eg a tensordict, a tensorclass or any pytree), but not a Storage object.

tensordict.set("index", index)
return
tensordict.set("index", expand_as_right(index, tensordict))

def update_tensordict_priority(self, data: TensorDictBase) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature of the method has changed since we added return self.update_priority(index, priority)

isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
try:
done = storage[:].get(self.end_key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I would suggest to use copy() or clone() instead of [:] to be more explicit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above

@Cadene
Copy link
Contributor

Cadene commented Feb 16, 2024

Didnt run the code, but overall looks good to me!

I already expressed my small concern about this feature. I would prefer to keep everything flattened for the sake of consistency, even tho it adds some complexity when rolling out multiple environments. Intuitively, using multiple environments should speed things up, not change the way we store trajectories. But if you think this way makes more sense, I trust you!

Thanks for the solid work on this feature and all the additional improvements.

@vmoens
Copy link
Contributor Author

vmoens commented Feb 20, 2024

@Cadene

I would prefer to keep everything flattened for the sake of consistency

This will still be the default behaviour. I agree about your point regarding using multiple envs, but one might also argue that using a ParallelEnv should not change the training performance, just speed things up. In the current state of the buffer, if you are using ParallelEnv or MultiSyncCollector you get a buffer that is organised differently than when using a single env, and if you're using a SliceSampler the training with a single env will give you to data that is better organised, hence a better training performance. This PR allows you to solve that.

@Cadene
Copy link
Contributor

Cadene commented Feb 21, 2024

Looks good to me. Thanks for addressing my comments :)

@vmoens vmoens marked this pull request as ready for review February 22, 2024 01:28
@vmoens vmoens merged commit c3bda41 into main Feb 22, 2024
21 of 30 checks passed
@vmoens vmoens deleted the extend-rb-dim1 branch February 22, 2024 03:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants