[Feature] batched trajectories - SliceSampler compatibility #1775
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1775
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Merge Blocking SEVs
There is 1 active merge blocking SEVs. Please view them below:
If you must merge, use
⏳ No Failures, 42 Pending
As of commit 36ea29f with merge base bb44067:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_single | 61.7605ms | 61.4829ms | 16.2647 Ops/s | 15.5168 Ops/s | |
test_sync | 38.9363ms | 33.4051ms | 29.9356 Ops/s | 28.4242 Ops/s | |
test_async | 55.3167ms | 31.2616ms | 31.9881 Ops/s | 30.6990 Ops/s | |
test_simple | 0.5124s | 0.4391s | 2.2772 Ops/s | 2.3381 Ops/s | |
test_transformed | 0.6322s | 0.5858s | 1.7071 Ops/s | 1.7136 Ops/s | |
test_serial | 1.4805s | 1.4325s | 0.6981 Ops/s | 0.6939 Ops/s | |
test_parallel | 1.4721s | 1.4204s | 0.7040 Ops/s | 0.7030 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1567ms | 22.0502μs | 45.3510 KOps/s | 46.6218 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 46.7070μs | 13.3538μs | 74.8849 KOps/s | 74.8470 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 43.0710μs | 12.8596μs | 77.7626 KOps/s | 79.8093 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 46.6170μs | 7.6798μs | 130.2115 KOps/s | 131.2072 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 58.1290μs | 23.6290μs | 42.3209 KOps/s | 43.7154 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 56.0450μs | 14.6297μs | 68.3541 KOps/s | 68.1826 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 41.9590μs | 13.9872μs | 71.4942 KOps/s | 72.5981 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 52.0170μs | 8.9891μs | 111.2456 KOps/s | 112.0181 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 54.6120μs | 24.9586μs | 40.0664 KOps/s | 40.7573 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 55.3240μs | 16.0185μs | 62.4277 KOps/s | 62.8074 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 65.9940μs | 14.0679μs | 71.0838 KOps/s | 72.8166 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 26.3090μs | 9.0204μs | 110.8594 KOps/s | 111.7196 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 80.9910μs | 25.9993μs | 38.4626 KOps/s | 38.9732 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 45.6960μs | 17.2574μs | 57.9462 KOps/s | 58.6715 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 49.7930μs | 15.0836μs | 66.2972 KOps/s | 66.9598 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 32.6910μs | 10.2589μs | 97.4759 KOps/s | 98.6249 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 57.3270μs | 24.8973μs | 40.1650 KOps/s | 40.7136 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 55.3840μs | 16.1713μs | 61.8378 KOps/s | 62.3015 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 38.9530μs | 16.4260μs | 60.8791 KOps/s | 62.2316 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 43.3820μs | 10.3518μs | 96.6017 KOps/s | 98.5983 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 42.6300μs | 26.7426μs | 37.3935 KOps/s | 38.3560 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 47.6090μs | 17.4203μs | 57.4043 KOps/s | 58.1870 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 62.2870μs | 17.5520μs | 56.9736 KOps/s | 58.1164 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 31.1980μs | 11.6036μs | 86.1805 KOps/s | 88.2029 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 70.3720μs | 27.3795μs | 36.5236 KOps/s | 37.0359 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 41.5080μs | 18.6850μs | 53.5189 KOps/s | 53.9434 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 63.2690μs | 17.5990μs | 56.8215 KOps/s | 57.9755 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 34.6550μs | 11.6243μs | 86.0267 KOps/s | 88.2819 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 76.1730μs | 28.6207μs | 34.9397 KOps/s | 35.8407 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 68.4680μs | 19.7379μs | 50.6638 KOps/s | 49.1215 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 43.6220μs | 18.5895μs | 53.7937 KOps/s | 54.3925 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 58.9200μs | 12.5987μs | 79.3734 KOps/s | 79.8398 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 11.2049ms | 9.3930ms | 106.4622 Ops/s | 111.2378 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 36.6587ms | 33.2840ms | 30.0444 Ops/s | 28.4262 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2310ms | 0.1897ms | 5.2711 KOps/s | 6.0523 KOps/s | |
test_values[td1_return_estimate-False-False] | 23.7823ms | 23.4988ms | 42.5554 Ops/s | 44.7817 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 47.8651ms | 33.7301ms | 29.6471 Ops/s | 28.2776 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 53.8597ms | 33.7039ms | 29.6702 Ops/s | 31.0197 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 34.8422ms | 33.2351ms | 30.0887 Ops/s | 28.3415 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 9.0191ms | 8.1233ms | 123.1026 Ops/s | 125.2650 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.2621ms | 1.9807ms | 504.8803 Ops/s | 504.6952 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.4238ms | 0.3459ms | 2.8911 KOps/s | 2.9210 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 42.9739ms | 40.4440ms | 24.7255 Ops/s | 22.3914 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.1889ms | 3.0309ms | 329.9346 Ops/s | 331.6252 Ops/s | |
test_dqn_speed | 6.9495ms | 1.3933ms | 717.7203 Ops/s | 717.7059 Ops/s | |
test_ddpg_speed | 4.7414ms | 2.7765ms | 360.1635 Ops/s | 361.5827 Ops/s | |
test_sac_speed | 73.8254ms | 8.8557ms | 112.9212 Ops/s | 120.7015 Ops/s | |
test_redq_speed | 14.2549ms | 13.1249ms | 76.1911 Ops/s | 73.5567 Ops/s | |
test_redq_deprec_speed | 13.8899ms | 13.2558ms | 75.4388 Ops/s | 76.6079 Ops/s | |
test_td3_speed | 8.6680ms | 8.3621ms | 119.5865 Ops/s | 122.3667 Ops/s | |
test_cql_speed | 37.7721ms | 36.2090ms | 27.6175 Ops/s | 26.3220 Ops/s | |
test_a2c_speed | 8.6546ms | 7.4201ms | 134.7685 Ops/s | 132.4834 Ops/s | |
test_ppo_speed | 8.8849ms | 7.6707ms | 130.3670 Ops/s | 127.3111 Ops/s | |
test_reinforce_speed | 7.2483ms | 6.6013ms | 151.4852 Ops/s | 144.2457 Ops/s | |
test_iql_speed | 33.8294ms | 32.5761ms | 30.6974 Ops/s | 30.8052 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.5044ms | 2.2437ms | 445.6973 Ops/s | 357.2003 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9650ms | 0.4969ms | 2.0127 KOps/s | 1.9663 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.9056ms | 0.5024ms | 1.9906 KOps/s | 1.9987 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.0501ms | 2.2505ms | 444.3364 Ops/s | 358.3803 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.0570ms | 0.4896ms | 2.0427 KOps/s | 1.9901 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7268ms | 0.4675ms | 2.1389 KOps/s | 2.1060 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.4259ms | 2.4072ms | 415.4250 Ops/s | 346.6520 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.1513ms | 0.6113ms | 1.6358 KOps/s | 1.6057 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0210ms | 0.5887ms | 1.6986 KOps/s | 1.6806 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.4080ms | 2.2041ms | 453.7061 Ops/s | 361.8144 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5914ms | 0.4932ms | 2.0275 KOps/s | 1.9594 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 90.9529ms | 0.5666ms | 1.7650 KOps/s | 2.0816 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.2751ms | 2.2741ms | 439.7311 Ops/s | 356.6487 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.0627ms | 0.4897ms | 2.0419 KOps/s | 1.9883 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8678ms | 0.4710ms | 2.1232 KOps/s | 2.0955 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.4008ms | 2.3610ms | 423.5445 Ops/s | 342.6617 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 88.4038ms | 0.7100ms | 1.4085 KOps/s | 1.5917 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8451ms | 0.5899ms | 1.6953 KOps/s | 1.6538 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 87.5562ms | 5.2885ms | 189.0889 Ops/s | 134.4109 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 14.0285ms | 11.8655ms | 84.2782 Ops/s | 68.0966 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.6598ms | 1.0415ms | 960.1530 Ops/s | 394.2132 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 86.0137ms | 6.8708ms | 145.5441 Ops/s | 140.4029 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 14.2004ms | 11.8091ms | 84.6807 Ops/s | 76.0607 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 4.2582ms | 1.1295ms | 885.3269 Ops/s | 392.8976 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 88.2532ms | 7.3028ms | 136.9344 Ops/s | 109.5690 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 14.8371ms | 12.1849ms | 82.0690 Ops/s | 74.9371 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.0048ms | 1.3395ms | 746.5291 Ops/s | 360.2437 Ops/s |

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_single | 0.1138s | 0.1133s | 8.8274 Ops/s | 8.3621 Ops/s | |
test_sync | 95.8094ms | 95.6176ms | 10.4583 Ops/s | 10.4368 Ops/s | |
test_async | 0.1818s | 91.9722ms | 10.8729 Ops/s | 10.9175 Ops/s | |
test_single_pixels | 0.1270s | 0.1266s | 7.8977 Ops/s | 7.3044 Ops/s | |
test_sync_pixels | 83.5962ms | 80.9378ms | 12.3552 Ops/s | 12.3215 Ops/s | |
test_async_pixels | 0.1506s | 67.8172ms | 14.7455 Ops/s | 14.7193 Ops/s | |
test_simple | 0.8807s | 0.8239s | 1.2138 Ops/s | 1.2196 Ops/s | |
test_transformed | 1.1043s | 1.0472s | 0.9549 Ops/s | 0.9614 Ops/s | |
test_serial | 2.4952s | 2.4411s | 0.4097 Ops/s | 0.4171 Ops/s | |
test_parallel | 2.1850s | 2.1109s | 0.4737 Ops/s | 0.4856 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1227ms | 32.7150μs | 30.5670 KOps/s | 30.3818 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 44.6810μs | 19.8417μs | 50.3988 KOps/s | 51.2168 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 41.0000μs | 18.4188μs | 54.2925 KOps/s | 54.9081 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 32.5210μs | 11.1704μs | 89.5225 KOps/s | 92.0545 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 72.6810μs | 34.2269μs | 29.2168 KOps/s | 29.4303 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 40.5210μs | 21.5466μs | 46.4111 KOps/s | 47.7799 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 42.5410μs | 20.0660μs | 49.8356 KOps/s | 50.1369 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 35.1300μs | 13.0503μs | 76.6267 KOps/s | 79.0527 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 69.0210μs | 36.2206μs | 27.6086 KOps/s | 27.4127 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 47.1110μs | 23.2517μs | 43.0076 KOps/s | 43.1437 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 43.3910μs | 19.9767μs | 50.0584 KOps/s | 49.5004 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 29.5400μs | 12.9599μs | 77.1611 KOps/s | 78.4424 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 62.2210μs | 38.2642μs | 26.1341 KOps/s | 26.0432 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 84.0510μs | 25.1269μs | 39.7979 KOps/s | 40.0967 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 45.2510μs | 21.7080μs | 46.0659 KOps/s | 45.5785 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 40.4100μs | 14.7858μs | 67.6326 KOps/s | 68.1797 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 57.4620μs | 37.0373μs | 26.9998 KOps/s | 27.6064 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 48.2710μs | 23.4426μs | 42.6573 KOps/s | 43.3717 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 39.7410μs | 23.9434μs | 41.7652 KOps/s | 41.8496 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 31.7210μs | 14.8048μs | 67.5455 KOps/s | 67.6930 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 64.9710μs | 38.3905μs | 26.0481 KOps/s | 25.8688 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 55.4010μs | 25.4911μs | 39.2294 KOps/s | 39.8928 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 43.4600μs | 25.3919μs | 39.3826 KOps/s | 38.6222 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 37.4600μs | 16.5701μs | 60.3496 KOps/s | 59.9255 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 66.8810μs | 40.0311μs | 24.9806 KOps/s | 24.5661 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 50.4310μs | 26.8823μs | 37.1992 KOps/s | 37.3448 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 48.6700μs | 25.6584μs | 38.9735 KOps/s | 39.1097 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 34.9600μs | 16.7036μs | 59.8673 KOps/s | 60.8397 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 60.0720μs | 41.1239μs | 24.3168 KOps/s | 24.3450 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 47.7410μs | 28.9886μs | 34.4963 KOps/s | 34.8719 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 48.3310μs | 27.3976μs | 36.4995 KOps/s | 37.2188 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 46.5610μs | 18.3354μs | 54.5393 KOps/s | 55.2981 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 25.4036ms | 24.5794ms | 40.6845 Ops/s | 40.0527 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 82.3914ms | 3.2026ms | 312.2462 Ops/s | 307.9053 Ops/s | |
test_values[td0_return_estimate-False-False] | 99.1520μs | 59.1047μs | 16.9191 KOps/s | 16.7833 KOps/s | |
test_values[td1_return_estimate-False-False] | 53.3329ms | 51.6958ms | 19.3439 Ops/s | 18.1632 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 2.0812ms | 1.7529ms | 570.4759 Ops/s | 566.9102 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 85.3080ms | 82.8142ms | 12.0752 Ops/s | 11.4524 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 3.9945ms | 1.7913ms | 558.2577 Ops/s | 557.9985 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 23.9872ms | 23.3662ms | 42.7969 Ops/s | 41.6757 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 0.8772ms | 0.6879ms | 1.4538 KOps/s | 1.4517 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7095ms | 0.6495ms | 1.5396 KOps/s | 1.4905 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.4905ms | 1.4445ms | 692.2683 Ops/s | 690.4366 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.9266ms | 0.6623ms | 1.5100 KOps/s | 1.5014 KOps/s | |
test_dqn_speed | 9.1106ms | 1.4882ms | 671.9731 Ops/s | 703.0199 Ops/s | |
test_ddpg_speed | 3.2327ms | 2.8339ms | 352.8712 Ops/s | 361.2393 Ops/s | |
test_sac_speed | 8.8175ms | 8.1519ms | 122.6713 Ops/s | 124.9361 Ops/s | |
test_redq_speed | 86.7981ms | 10.9651ms | 91.1983 Ops/s | 100.5016 Ops/s | |
test_redq_deprec_speed | 11.7090ms | 11.1497ms | 89.6883 Ops/s | 92.1160 Ops/s | |
test_td3_speed | 8.4923ms | 8.2300ms | 121.5071 Ops/s | 123.5740 Ops/s | |
test_cql_speed | 26.8404ms | 25.1334ms | 39.7877 Ops/s | 40.4258 Ops/s | |
test_a2c_speed | 5.8876ms | 5.4639ms | 183.0204 Ops/s | 183.8140 Ops/s | |
test_ppo_speed | 6.0088ms | 5.7648ms | 173.4668 Ops/s | 173.8513 Ops/s | |
test_reinforce_speed | 5.3957ms | 4.5003ms | 222.2057 Ops/s | 223.7258 Ops/s | |
test_iql_speed | 20.0048ms | 19.0825ms | 52.4041 Ops/s | 52.6664 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.9580ms | 2.8524ms | 350.5764 Ops/s | 277.6233 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8090ms | 0.5295ms | 1.8885 KOps/s | 1.8198 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7265ms | 0.5071ms | 1.9721 KOps/s | 1.9217 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.0744ms | 2.8903ms | 345.9855 Ops/s | 274.0566 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6267ms | 0.5261ms | 1.9008 KOps/s | 1.8436 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.4417ms | 0.5093ms | 1.9637 KOps/s | 1.9451 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.0581ms | 2.9776ms | 335.8424 Ops/s | 265.9182 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.9132ms | 0.6617ms | 1.5112 KOps/s | 1.3263 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8186ms | 0.6345ms | 1.5760 KOps/s | 1.5659 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.0149ms | 2.8582ms | 349.8752 Ops/s | 272.6674 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.7935ms | 0.5343ms | 1.8715 KOps/s | 1.8310 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.3150ms | 0.5136ms | 1.9471 KOps/s | 1.9145 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.0352ms | 2.8793ms | 347.3073 Ops/s | 272.9064 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6454ms | 0.5273ms | 1.8965 KOps/s | 1.8451 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 8.8636ms | 0.5150ms | 1.9417 KOps/s | 1.6296 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.1023ms | 2.9845ms | 335.0682 Ops/s | 264.8870 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.4746ms | 0.6609ms | 1.5130 KOps/s | 1.4986 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8272ms | 0.6363ms | 1.5716 KOps/s | 1.5560 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1076s | 8.5689ms | 116.7006 Ops/s | 88.9979 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 19.5644ms | 14.1955ms | 70.4448 Ops/s | 63.3763 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.1576ms | 1.0507ms | 951.7723 Ops/s | 334.2845 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1006s | 8.5018ms | 117.6225 Ops/s | 90.2459 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 16.6164ms | 14.0796ms | 71.0249 Ops/s | 63.6031 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 2.4815ms | 1.1792ms | 848.0278 Ops/s | 335.3423 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 99.2904ms | 6.9059ms | 144.8039 Ops/s | 105.3884 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 16.8748ms | 14.2394ms | 70.2277 Ops/s | 55.8256 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.5653ms | 1.6102ms | 621.0228 Ops/s | 304.5492 Ops/s |
# Conflicts: # torchrl/data/replay_buffers/replay_buffers.py # torchrl/envs/transforms/transforms.py
I just pushed a completely broken commit but here's the status: tensor([[3, 1],
        [4, 1],
        [5, 1]]) where the first column is the index along the time dimension and the second the index along the batch dimension (since dims are flipped in the buffer).
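To make the index layout described above concrete, here is a minimal pure-Python sketch. It is not torchrl code: it assumes a toy storage of shape [T, B] (time first, batch second), and the names `storage` and `slice_index` are hypothetical. It builds the (time, batch) index pairs for three consecutive steps in batch column 1 and gathers the corresponding transitions.

```python
# Toy 2-D storage: T time steps (rows) x B environments (columns).
# Each cell holds an (env_id, step) pair so we can see what gets sampled.
T, B = 6, 2
storage = [[(b, t) for b in range(B)] for t in range(T)]


def slice_index(start, length, batch):
    """Return [time, batch] index pairs for `length` consecutive steps."""
    return [[t, batch] for t in range(start, start + length)]


index = slice_index(start=3, length=3, batch=1)
print(index)  # [[3, 1], [4, 1], [5, 1]]

# Gathering with these pairs yields a flat list of transitions,
# all taken from the same environment column.
samples = [storage[t][b] for t, b in index]
print(samples)  # [(1, 3), (1, 4), (1, 5)]
```

The gathered result is one-dimensional even though the storage is two-dimensional, which is the same effect as indexing the storage with a boolean mask.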
storage = rb._storage._storage
else:
storage = rb._storage._storage.get("_data")
storage = rb._storage._storage[:]
I would suggest using copy() or clone() to be more explicit than [:]
- storage = rb._storage._storage[:]
+ storage = rb._storage._storage.copy()
The [:] does not copy; it indexes all valid elements of the storage (i.e. the storage up to its maximum length).
Note also that this can return anything (e.g. a tensordict, a tensorclass or any pytree), but not a Storage object.
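As a rough analogy for the behaviour described in this reply, here is a stdlib-only sketch. It assumes that "valid" means the elements written so far, and `ToyStorage` is a hypothetical class, not the torchrl storage. Slicing goes through a `memoryview`, so `[:]` returns a view over the written region rather than a copy, and the returned object is not itself a storage.

```python
from array import array


class ToyStorage:
    """Hypothetical fixed-capacity storage (not the torchrl class)."""

    def __init__(self, max_size):
        self._buffer = array("i", [0] * max_size)  # pre-allocated memory
        self._len = 0  # number of valid (written) elements

    def extend(self, values):
        for v in values:
            self._buffer[self._len] = v
            self._len += 1

    def __getitem__(self, item):
        # Index through a memoryview restricted to the valid region:
        # slicing a memoryview returns a view, not a copy.
        return memoryview(self._buffer)[: self._len][item]


storage = ToyStorage(max_size=8)
storage.extend([10, 11, 12])

valid = storage[:]     # only the 3 written elements, not all 8 slots
print(valid.tolist())  # [10, 11, 12]

valid[0] = 99          # writing through the view changes the storage
print(storage[0])      # 99
```

An explicit copy() or clone() would instead allocate new memory, which is a different operation from what `[:]` does in this sketch.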
tensordict.set("index", index)
return
tensordict.set("index", expand_as_right(index, tensordict))

def update_tensordict_priority(self, data: TensorDictBase) -> None:
The signature of the method has changed since we added return self.update_priority(index, priority)
isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
try:
done = storage[:].get(self.end_key)
Same here, I would suggest using copy() or clone() instead of [:] to be more explicit.
See comment above
Didn't run the code, but overall looks good to me! I already expressed my small concern about this feature. I would prefer to keep everything flattened for the sake of consistency, even though it adds some complexity when rolling out multiple environments. Intuitively, using multiple environments should speed things up, not change the way we store trajectories. But if you think this way makes more sense, I trust you! Thanks for the solid work on this feature and all the additional improvements.
This will still be the default behaviour. I agree about your point regarding using multiple envs, but one might also argue that using a ParallelEnv should not change the training performance, just speed things up. In the current state of the buffer, if you are using
Looks good to me. Thanks for addressing my comments :)
## What does this PR do?
We introduce a new way of extending replay buffers that is more faithful to the way trajectories are organised in torchrl.
Improvements:
- SliceSampler (and subclasses) can handle different batch dimensions. For multidim storages, samples are always flattened (as if the storage was indexed by a mask).

## Why is this change needed
We currently have 2 ways of extending a buffer, which I summarize in these ugly drawings.
The first attempt we had at storing trajectories was to build the buffer according to
Sampling consists of getting trajectories returned by the collector. This isn't amazing because trajectories could be interrupted in the middle with a truncated/terminated signal, or simply be longer than T.
The other issue is that your buffer size isn't the number of transitions but the number of trajectories stored. One can immediately see why this isn't as flexible as one could hope for!
Final drawback: subsampling slices needed to be done via rl/torchrl/envs/transforms/transforms.py (line 5712 in 31bea14).
The second attempt was simply to flatten the tensordict and use SliceSampler to get sub-trajectories.
Here the drawings above come in handy: we still have the problem that trajectories are not contiguous and can be truncated in multiple places if a ParallelEnv is used.
But it's closer to what people want and the buffer size is accurate.
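The non-contiguity problem can be reproduced with a small pure-Python sketch. It assumes a hypothetical `collect` function standing in for a collector that returns [B, T] batches from 2 parallel environments; none of these names come from torchrl. Flattening each batch before appending splits every environment's trajectory into separate chunks.

```python
def collect(batch_id, num_envs=2, steps=3):
    """Hypothetical collector output: one row per env, `steps` columns.

    Each transition is labelled (env_id, global_step).
    """
    start = batch_id * steps
    return [[(env, start + t) for t in range(steps)] for env in range(num_envs)]


flat_buffer = []
for batch_id in range(2):
    batch = collect(batch_id)
    # Flatten [B, T] row by row and append to the 1-D buffer.
    flat_buffer.extend(step for row in batch for step in row)

print(flat_buffer)
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2),
#  (0, 3), (0, 4), (0, 5), (1, 3), (1, 4), (1, 5)]

# Positions occupied by env 0: two separate chunks, not one contiguous run.
env0_positions = [i for i, (env, _) in enumerate(flat_buffer) if env == 0]
print(env0_positions)  # [0, 1, 2, 6, 7, 8]
```

Environment 0's steps 0 to 5 belong to one trajectory, yet they sit at positions 0-2 and 6-8, so a slice sampled across position 2 and 3 would mix two environments unless the sampler detects the boundary.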
## New API
The new API is quite simple: one just needs to pass the number of dimensions to the buffer and all the other components will know what to do:
This is the second picture I pasted earlier. The big advantage with SyncCollector + ParallelEnv is that you now have trajectories that are well structured and as long as you could wish.
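The idea behind a multi-dimensional storage can be sketched in plain Python. This is a toy model under stated assumptions, not the PR's implementation: `storage`, `cursor` and `extend` are hypothetical names, and the storage is laid out as [time, batch] so that each collector batch of shape [B, T] is written along the time dimension.

```python
T_MAX, B = 6, 2
# Toy [time, batch] storage; None marks an empty slot.
storage = [[None] * B for _ in range(T_MAX)]
cursor = 0  # next free row along the time dimension


def extend(batch):
    """Write a [B, T] batch into the [T, B] storage (dims are flipped)."""
    global cursor
    steps = len(batch[0])
    for t in range(steps):
        for env in range(len(batch)):
            storage[cursor + t][env] = batch[env][t]
    cursor += steps


# Two consecutive collector batches of 3 steps from 2 envs.
extend([[(0, 0), (0, 1), (0, 2)], [(1, 0), (1, 1), (1, 2)]])
extend([[(0, 3), (0, 4), (0, 5)], [(1, 3), (1, 4), (1, 5)]])

# Each column is one environment's trajectory, contiguous in time.
env0 = [row[0] for row in storage]
print(env0)  # [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
```

Because each environment keeps its own column, consecutive batches extend the same trajectory instead of interleaving it with the other environments.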
## Other changes
- TensorDictReplayBuffer is a bit buggy because we can't store the indices so I made this not possible anymore. I believe it is the only bc-breaking change of this PR, but all the features that used to be supported still are via a ReplayBuffer usage.
- "_data" prefix in storages placed within TensorDictReplayBuffer instances.

Gist: https://gist.github.com/vmoens/b928af1a3a9567b1f0862a58ec592f85
cc @btx0424 @albertbou92 @Cadene