[BugFix] Robust sync for non_blocking=True #2034
Conversation
🔗 Helpful links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2034
Note: links to docs will display an error until the docs builds have been completed.
❌ 1 new failure as of commit 2325a90 with merge base 247ed6e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_single | 53.3812ms | 52.2538ms | 19.1374 Ops/s | 18.0769 Ops/s | |
test_sync | 47.9155ms | 32.2791ms | 30.9798 Ops/s | 30.1361 Ops/s | |
test_async | 62.0251ms | 27.5733ms | 36.2670 Ops/s | 36.8335 Ops/s | |
test_simple | 0.4056s | 0.3469s | 2.8828 Ops/s | 2.8572 Ops/s | |
test_transformed | 0.5377s | 0.4787s | 2.0892 Ops/s | 2.0005 Ops/s | |
test_serial | 1.2483s | 1.1885s | 0.8414 Ops/s | 0.8267 Ops/s | |
test_parallel | 1.0951s | 1.0241s | 0.9765 Ops/s | 0.9974 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1779ms | 20.9846μs | 47.6539 KOps/s | 47.5760 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 42.5900μs | 12.8893μs | 77.5837 KOps/s | 76.6016 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 37.9200μs | 12.4010μs | 80.6385 KOps/s | 80.5848 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 56.9160μs | 7.5341μs | 132.7294 KOps/s | 131.4300 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 57.0570μs | 22.2721μs | 44.8992 KOps/s | 44.5525 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 49.9840μs | 13.9849μs | 71.5057 KOps/s | 70.0538 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 39.8350μs | 13.5621μs | 73.7348 KOps/s | 73.0396 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 43.2710μs | 8.7474μs | 114.3194 KOps/s | 112.8041 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 70.6820μs | 23.5723μs | 42.4226 KOps/s | 42.0586 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 74.6430μs | 15.2511μs | 65.5689 KOps/s | 63.8310 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 55.4640μs | 13.5681μs | 73.7025 KOps/s | 73.6789 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 35.0460μs | 8.8790μs | 112.6259 KOps/s | 113.0591 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 61.7960μs | 24.7389μs | 40.4222 KOps/s | 40.4233 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 52.7380μs | 16.5347μs | 60.4790 KOps/s | 59.3943 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 54.4820μs | 14.7809μs | 67.6550 KOps/s | 68.1479 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 43.4420μs | 9.9843μs | 100.1574 KOps/s | 99.2767 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 60.5930μs | 23.6723μs | 42.2435 KOps/s | 41.4740 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 51.6860μs | 15.3864μs | 64.9926 KOps/s | 63.6592 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 45.7050μs | 15.7093μs | 63.6566 KOps/s | 63.3692 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 43.4010μs | 9.9342μs | 100.6622 KOps/s | 100.1730 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 48.5510μs | 25.0053μs | 39.9916 KOps/s | 39.2117 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 46.3870μs | 16.6133μs | 60.1926 KOps/s | 59.0957 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 49.6630μs | 16.9011μs | 59.1678 KOps/s | 59.1020 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 46.8880μs | 11.1301μs | 89.8461 KOps/s | 89.5452 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 63.3090μs | 26.0027μs | 38.4576 KOps/s | 37.0821 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 53.4600μs | 17.7800μs | 56.2431 KOps/s | 53.3585 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 67.7570μs | 17.0405μs | 58.6837 KOps/s | 58.5954 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 44.6640μs | 11.1831μs | 89.4204 KOps/s | 88.9091 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 69.3400μs | 26.9193μs | 37.1480 KOps/s | 36.5243 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 54.5930μs | 18.7920μs | 53.2143 KOps/s | 52.2920 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 46.2070μs | 17.8902μs | 55.8964 KOps/s | 55.9680 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 45.8760μs | 12.2180μs | 81.8463 KOps/s | 81.5936 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 12.3607ms | 9.5962ms | 104.2079 Ops/s | 105.8152 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 39.2443ms | 33.8291ms | 29.5604 Ops/s | 28.2883 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2455ms | 0.1706ms | 5.8618 KOps/s | 5.7371 KOps/s | |
test_values[td1_return_estimate-False-False] | 26.4533ms | 23.8980ms | 41.8445 Ops/s | 42.0094 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 35.3396ms | 33.5788ms | 29.7807 Ops/s | 28.2830 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.4173ms | 34.1682ms | 29.2670 Ops/s | 29.8069 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 35.0090ms | 33.5188ms | 29.8340 Ops/s | 28.2407 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 9.6570ms | 8.3257ms | 120.1103 Ops/s | 121.2668 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.3614ms | 2.0689ms | 483.3427 Ops/s | 514.6278 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.4112ms | 0.3533ms | 2.8306 KOps/s | 2.8742 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 50.5891ms | 47.7565ms | 20.9395 Ops/s | 23.9934 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 3.5888ms | 3.0616ms | 326.6311 Ops/s | 329.5421 Ops/s | |
test_dqn_speed | 7.8260ms | 1.3699ms | 729.9610 Ops/s | 669.3932 Ops/s | |
test_ddpg_speed | 3.6461ms | 2.7266ms | 366.7597 Ops/s | 367.1627 Ops/s | |
test_sac_speed | 10.4166ms | 8.4116ms | 118.8829 Ops/s | 121.2214 Ops/s | |
test_redq_speed | 94.9047ms | 14.3795ms | 69.5433 Ops/s | 75.5344 Ops/s | |
test_redq_deprec_speed | 14.9376ms | 13.7393ms | 72.7841 Ops/s | 73.3106 Ops/s | |
test_td3_speed | 8.7984ms | 8.3424ms | 119.8693 Ops/s | 120.9407 Ops/s | |
test_cql_speed | 37.6596ms | 36.1955ms | 27.6278 Ops/s | 27.4600 Ops/s | |
test_a2c_speed | 8.1571ms | 7.3789ms | 135.5220 Ops/s | 134.6131 Ops/s | |
test_ppo_speed | 8.6615ms | 7.7069ms | 129.7543 Ops/s | 128.5031 Ops/s | |
test_reinforce_speed | 7.2158ms | 6.6299ms | 150.8308 Ops/s | 152.2161 Ops/s | |
test_iql_speed | 34.5493ms | 33.0329ms | 30.2729 Ops/s | 30.9841 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.8972ms | 2.2395ms | 446.5236 Ops/s | 444.9158 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.4089ms | 0.4939ms | 2.0247 KOps/s | 2.0215 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7395ms | 0.4676ms | 2.1387 KOps/s | 2.1325 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.3491ms | 2.2466ms | 445.1155 Ops/s | 428.7314 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.7539ms | 0.4879ms | 2.0496 KOps/s | 1.9979 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 3.6852ms | 0.4682ms | 2.1360 KOps/s | 2.0642 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.3341ms | 1.2439ms | 803.8918 Ops/s | 815.7358 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 4.6724ms | 1.1656ms | 857.9515 Ops/s | 868.8682 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.5226ms | 2.4457ms | 408.8833 Ops/s | 405.1121 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.8853ms | 0.6190ms | 1.6155 KOps/s | 1.6165 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8152ms | 0.5851ms | 1.7092 KOps/s | 1.6744 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.3197ms | 2.2710ms | 440.3430 Ops/s | 433.4284 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8321ms | 0.5008ms | 1.9968 KOps/s | 2.0101 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6422ms | 0.4767ms | 2.0979 KOps/s | 2.1251 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.4052ms | 2.3545ms | 424.7185 Ops/s | 408.5122 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.0240ms | 0.4972ms | 2.0111 KOps/s | 1.9995 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7533ms | 0.4750ms | 2.1052 KOps/s | 2.0758 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.5436ms | 2.4355ms | 410.5885 Ops/s | 400.5543 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.9506ms | 0.6238ms | 1.6032 KOps/s | 1.5984 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7556ms | 0.5944ms | 1.6825 KOps/s | 1.6648 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1193s | 7.8829ms | 126.8566 Ops/s | 127.6588 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 14.7851ms | 12.2329ms | 81.7470 Ops/s | 82.5106 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.2416ms | 1.0454ms | 956.5410 Ops/s | 947.9108 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1122s | 5.8657ms | 170.4839 Ops/s | 171.4537 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 15.2906ms | 12.4909ms | 80.0581 Ops/s | 71.2459 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 1.6396ms | 1.1473ms | 871.6401 Ops/s | 951.7578 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1136s | 8.3589ms | 119.6327 Ops/s | 166.1782 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 15.2635ms | 12.7276ms | 78.5697 Ops/s | 81.2335 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 1.9696ms | 1.4017ms | 713.4371 Ops/s | 732.5273 Ops/s |
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_single | 0.1014s | 97.8544ms | 10.2193 Ops/s | 9.5566 Ops/s | |
test_sync | 86.0001ms | 85.6991ms | 11.6687 Ops/s | 11.5169 Ops/s | |
test_async | 0.1594s | 69.8134ms | 14.3239 Ops/s | 14.1488 Ops/s | |
test_single_pixels | 0.1063s | 0.1059s | 9.4391 Ops/s | 9.2682 Ops/s | |
test_sync_pixels | 67.6936ms | 65.7022ms | 15.2202 Ops/s | 15.2665 Ops/s | |
test_async_pixels | 0.1232s | 55.5644ms | 17.9971 Ops/s | 18.1080 Ops/s | |
test_simple | 0.7305s | 0.6634s | 1.5074 Ops/s | 1.4933 Ops/s | |
test_transformed | 0.9271s | 0.8621s | 1.1600 Ops/s | 1.1277 Ops/s | |
test_serial | 2.0956s | 2.0272s | 0.4933 Ops/s | 0.4803 Ops/s | |
test_parallel | 1.7995s | 1.7449s | 0.5731 Ops/s | 0.5696 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 99.7420μs | 32.9167μs | 30.3797 KOps/s | 30.4911 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 36.7910μs | 19.4978μs | 51.2879 KOps/s | 50.4947 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 46.3100μs | 18.6797μs | 53.5341 KOps/s | 54.2319 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 31.4100μs | 11.0890μs | 90.1795 KOps/s | 89.4312 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 55.2810μs | 34.6099μs | 28.8935 KOps/s | 29.2397 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 45.0610μs | 21.1649μs | 47.2480 KOps/s | 46.6955 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 42.2810μs | 20.4130μs | 48.9884 KOps/s | 49.3859 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 35.5300μs | 12.8817μs | 77.6295 KOps/s | 75.9785 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 63.8310μs | 35.9247μs | 27.8360 KOps/s | 27.5604 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 41.2800μs | 23.3163μs | 42.8885 KOps/s | 42.3196 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 38.3000μs | 20.5465μs | 48.6701 KOps/s | 49.2037 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 33.7210μs | 12.9322μs | 77.3261 KOps/s | 74.9765 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 64.9610μs | 38.2332μs | 26.1553 KOps/s | 26.1388 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 54.6310μs | 25.0452μs | 39.9278 KOps/s | 39.8992 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 49.9910μs | 22.2770μs | 44.8893 KOps/s | 44.3974 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 33.7810μs | 14.7767μs | 67.6740 KOps/s | 67.8884 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 54.1220μs | 36.4078μs | 27.4666 KOps/s | 27.3672 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 42.7510μs | 23.3480μs | 42.8302 KOps/s | 41.5752 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 52.4310μs | 24.2670μs | 41.2082 KOps/s | 40.7793 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 39.2310μs | 14.8430μs | 67.3718 KOps/s | 67.6280 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 61.5010μs | 38.5556μs | 25.9366 KOps/s | 26.0764 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 46.2710μs | 25.3245μs | 39.4874 KOps/s | 39.0774 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 46.1110μs | 26.1657μs | 38.2180 KOps/s | 37.6717 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 34.7300μs | 16.7556μs | 59.6817 KOps/s | 60.0190 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 68.4110μs | 40.1688μs | 24.8949 KOps/s | 25.2258 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 44.9910μs | 27.1752μs | 36.7983 KOps/s | 37.2913 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 83.0020μs | 26.3603μs | 37.9358 KOps/s | 37.8107 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 42.8310μs | 16.8230μs | 59.4425 KOps/s | 60.1464 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 61.7010μs | 41.5335μs | 24.0770 KOps/s | 23.9607 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 53.8810μs | 28.6718μs | 34.8775 KOps/s | 34.2294 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 47.7800μs | 27.8960μs | 35.8474 KOps/s | 35.5418 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 38.1210μs | 18.5470μs | 53.9170 KOps/s | 54.2898 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 24.0514ms | 23.7075ms | 42.1807 Ops/s | 39.8152 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 87.1013ms | 3.3046ms | 302.6097 Ops/s | 313.7784 Ops/s | |
test_values[td0_return_estimate-False-False] | 97.5320μs | 63.7114μs | 15.6958 KOps/s | 15.3121 KOps/s | |
test_values[td1_return_estimate-False-False] | 53.3282ms | 52.8591ms | 18.9182 Ops/s | 18.2509 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 2.1155ms | 1.7569ms | 569.1972 Ops/s | 565.3292 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 86.0341ms | 84.2941ms | 11.8632 Ops/s | 11.2388 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 2.1186ms | 1.7599ms | 568.2181 Ops/s | 565.3171 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 23.9720ms | 23.7783ms | 42.0551 Ops/s | 39.7676 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 0.9166ms | 0.6981ms | 1.4324 KOps/s | 1.3685 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7198ms | 0.6446ms | 1.5514 KOps/s | 1.4692 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.4931ms | 1.4516ms | 688.8765 Ops/s | 681.7670 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.9518ms | 0.6674ms | 1.4983 KOps/s | 1.4534 KOps/s | |
test_dqn_speed | 8.2818ms | 1.4277ms | 700.4214 Ops/s | 690.3867 Ops/s | |
test_ddpg_speed | 3.4031ms | 2.6670ms | 374.9577 Ops/s | 363.5583 Ops/s | |
test_sac_speed | 8.6774ms | 7.9749ms | 125.3933 Ops/s | 122.1582 Ops/s | |
test_redq_speed | 10.8731ms | 10.0068ms | 99.9321 Ops/s | 97.6933 Ops/s | |
test_redq_deprec_speed | 11.2814ms | 10.6843ms | 93.5953 Ops/s | 88.6557 Ops/s | |
test_td3_speed | 8.5192ms | 7.9164ms | 126.3206 Ops/s | 123.2286 Ops/s | |
test_cql_speed | 25.5961ms | 24.7304ms | 40.4360 Ops/s | 36.4064 Ops/s | |
test_a2c_speed | 5.7827ms | 5.5043ms | 181.6757 Ops/s | 177.9162 Ops/s | |
test_ppo_speed | 6.8785ms | 5.8617ms | 170.5994 Ops/s | 167.5779 Ops/s | |
test_reinforce_speed | 5.2829ms | 4.4497ms | 224.7345 Ops/s | 220.7949 Ops/s | |
test_iql_speed | 19.9920ms | 19.3294ms | 51.7347 Ops/s | 51.2957 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.0109ms | 2.8586ms | 349.8204 Ops/s | 344.3401 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6297ms | 0.5317ms | 1.8807 KOps/s | 1.8395 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.2383ms | 0.5133ms | 1.9483 KOps/s | 1.9396 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.0620ms | 2.8855ms | 346.5606 Ops/s | 341.0314 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6239ms | 0.5211ms | 1.9190 KOps/s | 1.8634 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.1128s | 0.6022ms | 1.6604 KOps/s | 1.9580 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5094ms | 1.4045ms | 711.9955 Ops/s | 683.6311 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4845ms | 1.3316ms | 750.9801 Ops/s | 719.1052 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.1987ms | 2.9981ms | 333.5391 Ops/s | 331.7543 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.5450ms | 0.6546ms | 1.5277 KOps/s | 1.4292 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8553ms | 0.6304ms | 1.5864 KOps/s | 1.5492 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.9615ms | 2.8643ms | 349.1195 Ops/s | 342.4073 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.3285ms | 0.5334ms | 1.8748 KOps/s | 1.8526 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7129ms | 0.5145ms | 1.9438 KOps/s | 1.9274 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.0832ms | 2.9075ms | 343.9372 Ops/s | 342.7381 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6510ms | 0.5273ms | 1.8965 KOps/s | 1.8657 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.4694ms | 0.5076ms | 1.9702 KOps/s | 1.9559 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.1057ms | 3.0055ms | 332.7223 Ops/s | 327.3884 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.6531ms | 0.6566ms | 1.5231 KOps/s | 1.4853 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7896ms | 0.6328ms | 1.5802 KOps/s | 1.5270 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1341s | 7.4488ms | 134.2506 Ops/s | 135.5788 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 17.2711ms | 14.8093ms | 67.5253 Ops/s | 64.7254 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.4231ms | 1.1622ms | 860.4394 Ops/s | 923.4219 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1170s | 7.0504ms | 141.8363 Ops/s | 142.1833 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 16.8077ms | 14.5546ms | 68.7067 Ops/s | 56.5662 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.4339ms | 1.2666ms | 789.5193 Ops/s | 938.2037 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1183s | 9.6925ms | 103.1722 Ops/s | 133.9202 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 17.3496ms | 14.9293ms | 66.9822 Ops/s | 63.5687 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.2612ms | 1.5622ms | 640.1262 Ops/s | 699.1795 Ops/s |
Happy to test when ready!
@skandermoalla It should be ready for a test. Here is my script:

```python
import torch
import torch.utils.benchmark

from torchrl.envs import GymEnv, ParallelEnv, SerialEnv

if __name__ == "__main__":
    for mp_start_method in [None, "spawn", "fork"]:
        for worker_device in ["cpu"]:
            for main_device in ["mps", "cpu"]:
                print(mp_start_method, worker_device, main_device)
                env = ParallelEnv(
                    16,
                    lambda: GymEnv("CartPole-v1", device=worker_device),
                    device=torch.device(main_device),
                    non_blocking=True,
                    mp_start_method=mp_start_method)
                env.rollout(2)
                print(torch.utils.benchmark.Timer(
                    "env.rollout(1000, break_when_any_done=False)",
                    globals=globals()).adaptive_autorange())
                torch.manual_seed(0)
                env.set_seed(0)
                rollout = env.rollout(1000, break_when_any_done=False)
                act0 = rollout["observation"].squeeze()
                torch.manual_seed(0)
                env.set_seed(0)
                rollout = env.rollout(1000, break_when_any_done=False)
                act1 = rollout["observation"].squeeze()
                torch.testing.assert_close(act0, act1)
                env.close()
                del env
```

I get a throughput of 11K fps on cpu-cpu-"fork" for Pendulum and 9.1K for CartPole (I get to 16K with 32 workers!). One thing I'm considering is to provide a transform that maps only certain tensors to a given device: if your policy is on CUDA, you may not need to move everything to CUDA, just the input keys. Since we now support partial devices (i.e. envs with no proper device), this could be a solution to minimize data movements.
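For illustration, something like the following hypothetical helper (not an existing TorchRL transform), assuming a deviceless tensordict as produced with the partial-device support mentioned above:

```python
import torch
from tensordict import TensorDict


def move_keys_to_device(td, keys, device):
    # Hypothetical helper: cast only the entries the policy needs,
    # leaving everything else on its original device.
    for key in keys:
        td.set(key, td.get(key).to(device, non_blocking=True))
    return td


td = TensorDict(
    {"observation": torch.randn(4, 8), "other_info": torch.zeros(4, 1)},
    batch_size=[4],
    device=None,  # deviceless tensordict: entries may live on different devices
)
td = move_keys_to_device(td, ["observation"], "cpu")  # e.g. "cuda" for a CUDA policy
```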
@vmoens Are you using the Gymnasium 1.0 alpha? I get this with the latest stable Gymnasium; I think this behavior changed in 1.0.

```
torchrl ❯ python vmoens.py
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
None cpu mps
Traceback (most recent call last):
File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 11, in <module>
env = ParallelEnv(
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 125, in __call__
return super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
instance: EnvBase = super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 335, in __init__
self._get_metadata(create_env_fn, create_env_kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 449, in _get_metadata
meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0])
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/env_creator.py", line 180, in get_env_metadata
env = env_or_creator(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 13, in <lambda>
lambda: GymEnv("CartPole-v1", device=worker_device),
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 543, in __call__
instance: GymWrapper = super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
instance: EnvBase = super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1277, in __init__
super().__init__(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 731, in __init__
super().__init__(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2930, in __init__
self._init_env() # runs all the steps to have a ready-to-use env
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1139, in _init_env
self.reset()
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2070, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1165, in _reset
return super()._reset(tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/gym_like.py", line 378, in _reset
self._sync_device()
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2945, in _sync_device
return self._sync_orig_device
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2967, in __getattr__
return getattr(env, attr)
File "/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py", line 310, in __getattr__
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
AttributeError: accessing private attribute '_sync_orig_device' is prohibited
~/projects/open-source/TorchRL/tests/mps main*
torchrl ❯ pip list | grep gymnasium
gymnasium 0.29.1
```
Still doesn't work with 1.0 actually.

```
torchrl ❯ python vmoens.py
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
None cpu mps
Traceback (most recent call last):
File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 11, in <module>
env = ParallelEnv(
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 125, in __call__
return super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
instance: EnvBase = super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 335, in __init__
self._get_metadata(create_env_fn, create_env_kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 449, in _get_metadata
meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0])
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/env_creator.py", line 180, in get_env_metadata
env = env_or_creator(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 13, in <lambda>
lambda: GymEnv("CartPole-v1", device=worker_device),
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 543, in __call__
instance: GymWrapper = super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
instance: EnvBase = super().__call__(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1277, in __init__
super().__init__(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 731, in __init__
super().__init__(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2930, in __init__
self._init_env() # runs all the steps to have a ready-to-use env
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1139, in _init_env
self.reset()
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2070, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1165, in _reset
return super()._reset(tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/gym_like.py", line 378, in _reset
self._sync_device()
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2945, in _sync_device
return self._sync_orig_device
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2967, in __getattr__
return getattr(env, attr)
AttributeError: 'TimeLimit' object has no attribute '_sync_orig_device'
~/projects/open-source/TorchRL/tests/mps main*
torchrl ❯ pip list | grep gymnasium
gymnasium 1.0.0a1
```
Here is what I have:

```
torchrl      0.4.0+404805a   /Users/moalla/projects/open-source/TorchRL/repos/rl
tensordict   0.4.0+2dc0285   /Users/moalla/projects/open-source/TorchRL/repos/tensordict
```
Okay, something broke between two commits. Here is what I get:

```
torchrl ❯ PYTHONWARNINGS="ignore" python vmoens.py
None cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x132852710>
env.rollout(1000, break_when_any_done=False)
Median: 2.76 s
IQR: 0.01 s (2.75 to 2.76)
4 measurements, 1 runs per measurement, 1 thread
None cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x127b9c850>
env.rollout(1000, break_when_any_done=False)
Median: 1.85 s
IQR: 0.01 s (1.85 to 1.85)
4 measurements, 1 runs per measurement, 1 thread
spawn cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x1409a1150>
env.rollout(1000, break_when_any_done=False)
Median: 2.77 s
IQR: 0.03 s (2.76 to 2.79)
4 measurements, 1 runs per measurement, 1 thread
spawn cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x127bcf610>
env.rollout(1000, break_when_any_done=False)
Median: 1.86 s
IQR: 0.02 s (1.86 to 1.88)
4 measurements, 1 runs per measurement, 1 thread
fork cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x1328531f0>
env.rollout(1000, break_when_any_done=False)
Median: 2.79 s
IQR: 0.01 s (2.78 to 2.79)
4 measurements, 1 runs per measurement, 1 thread
fork cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x104d87eb0>
env.rollout(1000, break_when_any_done=False)
Median: 1.94 s
IQR: 0.03 s (1.92 to 1.95)
4 measurements, 1 runs per measurement, 1 thread
```
With this, the previous issue in #1864 (comment) is solved. However, I found another bug (which may be coming from PyTorch).

```python
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    ExplorationType,
    ParallelEnv,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
env_id = "CartPole-v1"
device = "mps"


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    env = ParallelEnv(4, lambda: GymEnv(env_id), device=device)
    # Changing to a serial env removes the problem.
    env = TransformedEnv(env)
    # Or removing the transformed env removes the problem.
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())
    for i in range(10):
        batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
```

```
torchrl ❯ python issue_transformed_env.py
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
warnings.warn('Lazy modules are a new feature under heavy development '
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/utils.py:166: UserWarning: The expected key set and actual key set differ. This will work but with a slower throughput than when the specs match exactly the actual key set in the data. Expected - Actual keys=set(),
Actual - Expected keys={'logits'}.
warnings.warn(
Traceback (most recent call last):
File "/Users/moalla/projects/open-source/TorchRL/tests/mps/issue_transformed_env.py", line 40, in <module>
batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2504, in rollout
tensordicts = self._rollout_nonstop(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2606, in _rollout_nonstop
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2662, in step_and_maybe_reset
tensordict_ = self.maybe_reset(tensordict_)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2707, in maybe_reset
tensordict = self.reset(tensordict)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2069, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 798, in _reset
tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
File "/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 60, in decorated_fun
return fun(self, *args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1486, in _reset
self.shared_tensordicts[i].update_(
File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 2859, in update_
self._apply_nest(
File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_td.py", line 761, in _apply_nest
item_trsf = fn(unravel_key(prefix + (key,)), item, *_others)
File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 2843, in inplace_update
dest.copy_(source, non_blocking=non_blocking)
RuntimeError: destOffset % 4 == 0 INTERNAL ASSERT FAILED at "/Users/runner/work/_temp/anaconda/conda-bld/pytorch_1704987091277/work/aten/src/ATen/native/mps/operations/Copy.mm":107, please report a bug to PyTorch. Unaligned blit request
```
Yeah, that was an error on my side; I had pushed a commit with a typo. It should be fixed now.
Cool. Works with the latest commit.

```
None cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x14bda42e0>
env.rollout(1000, break_when_any_done=False)
Median: 2.97 s
IQR: 0.09 s (2.93 to 3.02)
4 measurements, 1 runs per measurement, 1 thread
None cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x1509a51e0>
env.rollout(1000, break_when_any_done=False)
Median: 1.93 s
IQR: 0.02 s (1.92 to 1.95)
4 measurements, 1 runs per measurement, 1 thread
spawn cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x13299c850>
env.rollout(1000, break_when_any_done=False)
Median: 2.95 s
IQR: 0.04 s (2.93 to 2.97)
4 measurements, 1 runs per measurement, 1 thread
spawn cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x14bd47e80>
env.rollout(1000, break_when_any_done=False)
Median: 1.99 s
IQR: 0.06 s (1.98 to 2.05)
4 measurements, 1 runs per measurement, 1 thread
fork cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x150937460>
env.rollout(1000, break_when_any_done=False)
Median: 3.19 s
IQR: 0.05 s (3.16 to 3.21)
4 measurements, 1 runs per measurement, 1 thread
fork cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x14ba7c190>
env.rollout(1000, break_when_any_done=False)
Median: 2.25 s
IQR: 0.07 s (2.21 to 2.28)
4 measurements, 1 runs per measurement, 1 thread
```

Thanks for this! With (1) and (2) I guess the problem is solved. I'm not sure if (3) is in the scope of this PR.
(3) is still a barrier to using MPS in practice with TorchRL though, so it would be good to address it soon too.
Unfortunately this one is known and still not solved. Upvote, comment, and make yourself heard on that issue!
```python
elif self.device is not None and self.device.type == "mps":
    # copy_ fails when moving mps->cpu using copy_
    # in some cases when a view of an mps tensor is used.
    # We know the shared tensors are not MPS, so we can
    # safely assume that the shared tensors are on cpu
    tensordict_ = tensordict_.to("cpu")
```
@skandermoalla this will make the class usable, though it's not optimal.
This `_reset` will not be used much in practice, and the info copied will usually be lightweight (e.g. `"_reset"` entries), so it should not slow things down too much.
Thanks for this!
```python
warnings.warn(
    "Your wrapper was not given a device. Currently, this "
    "value will default to 'cpu'. From v0.5 it will "
    "default to `None`. With a device of None, no device casting "
    "is performed and the resulting tensordicts are deviceless. "
    "Please set your device accordingly.",
    category=DeprecationWarning,
)
```
Asking for feedback on this
@matteobettini @skandermoalla @albertbou92 @BY571 @alexander-soare @dennismalmgren
I know this is going to be massively annoying, but we need to deprecate defaulting env devices to CPU.
First, it's arbitrary (an env can be on cuda by default) and restrictive (we now support envs that have part of their input/output on cpu and part on another device).
The only option I see to avoid having this warning pop up all the time is to make an environment variable that you can set to opt into the non-deprecated behaviour. I personally don't like relying on environment variables...
Wdyt?
I'm not sure I understand the problem.
Here's my understanding:
- We want to deprecate the default device of an env.
- If users don't want any casting, with the current version they can already just not set the device (or set it to None?). This is valid behavior.
- However, this will raise a (useless) warning until v0.5, and we want to avoid that.
- Nevertheless, we want the warning to be raised when the user is not aware of the deprecation and just didn't set a device.

This is a new feature, so users will be changing their code anyway when adopting it, so an environment variable is not necessary, I'd say. They can set an optional keyword argument `no_default_device=True` until v0.5.
"Please set your device accordingly." was a bit confusing. On my first read I thought that meant I have to set a device explicitly and None is not a device. On my second read, I think it just means I have to be careful when picking between a device and None.
> However, this will raise a (useless) warning until v0.5, and we want to avoid that.
> Nevertheless, we want the warning to be raised when the user is not aware of the deprecation and just didn't set a device.

Yes, so the only thing I want is that people who used to do `assert env.device == "cpu"` don't see their code breaking after 0.4, but they know that it will break in v0.5.
So I take it that we should keep the warning then, right? I should maybe use the torchrl logger to allow people to disable these warnings all at once.
> Yes, so the only thing I want is that people who used to do `assert env.device == "cpu"` don't see their code breaking after 0.4, but they know that it will break in v0.5.

This is what the current code does. So yes, keep the warning.
For the early adopters, they will change their code anyway, so it's fine to add an argument to suppress the warning IMO.
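For illustration, a minimal sketch of the keyword-argument idea discussed above; `resolve_device` and `no_default_device` are hypothetical names, not an existing TorchRL API:

```python
import warnings


def resolve_device(device=None, no_default_device=False):
    # Hypothetical helper: warn only when the user expressed no preference at all.
    if device is not None or no_default_device:
        return device
    warnings.warn(
        "No device was given; defaulting to 'cpu'. From v0.5 the default will be None.",
        category=DeprecationWarning,
    )
    return "cpu"
```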
@skandermoalla @dennismalmgren @albanD @alexander-soare
To test this (replace mps with cuda if available):
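A minimal sketch of such a test, assuming the same seeded-rollout comparison used earlier in this thread; the env, worker count, and rollout length below are illustrative:

```python
import torch
from torchrl.envs import GymEnv, ParallelEnv

if __name__ == "__main__":
    # Two rollouts with identical seeds must match; with non_blocking=True and
    # no robust sync, data crossing the cpu/mps boundary can be read too early.
    env = ParallelEnv(
        4,
        lambda: GymEnv("CartPole-v1", device="cpu"),
        device="mps",  # replace mps with cuda if available
        non_blocking=True,
    )
    observations = []
    for _ in range(2):
        torch.manual_seed(0)
        env.set_seed(0)
        rollout = env.rollout(100, break_when_any_done=False)
        observations.append(rollout["observation"].clone())
    env.close()
    torch.testing.assert_close(observations[0], observations[1])  # the "final assertion"
```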
The final assertion fails on main but succeeds here.
The runtime is 2x higher with the sync on my machine (but the results are accurate)
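For context on what the sync buys, here is the underlying PyTorch behavior in isolation (plain PyTorch, not TorchRL-specific): a non-blocking device-to-host copy returns before the data has actually landed, so the destination must not be read until the device is synchronized.

```python
import torch

if torch.cuda.is_available():
    src = torch.randn(1_000_000, device="cuda")
    # Pinned CPU memory is required for the copy to be truly asynchronous.
    dst = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
    dst.copy_(src, non_blocking=True)  # queued on the CUDA stream, returns immediately
    torch.cuda.synchronize()  # without this, `dst` may still hold stale data
    assert torch.equal(dst, src.cpu())
```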