
[BugFix] Robust sync for non_blocking=True #2034

Merged: vmoens merged 24 commits into main from fix-nonblocking-once-more on Mar 26, 2024

Conversation

vmoens
Contributor

@vmoens commented Mar 21, 2024

@skandermoalla @dennismalmgren @albanD @alexander-soare

To test this (replace mps with cuda if available):

import torch
from tensordict.nn import TensorDictModule

from torchrl.envs import GymEnv
from torchrl.modules import MLP
import torch.utils.benchmark

if __name__ == "__main__":
    env = GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(
        MLP(
            env.observation_spec["observation"].shape[-1],
            env.action_spec.shape[-1],
            num_cells=(64, 64),
            device="mps",
        ),
        in_keys=["observation"], out_keys=["action"])

    print(
        torch.utils.benchmark.Timer(
            "env.rollout(1000, policy=policy, auto_cast_to_device=True, break_when_any_done=False)",
            globals=globals(),
        ).adaptive_autorange()
    )

    torch.manual_seed(0)
    env.set_seed(0)
    rollout = env.rollout(1000, policy=policy, auto_cast_to_device=True, break_when_any_done=False)
    act0 = rollout["action"].squeeze()
    torch.manual_seed(0)
    env.set_seed(0)
    rollout = env.rollout(1000, policy=policy, auto_cast_to_device=True, break_when_any_done=False)
    act1 = rollout["action"].squeeze()
    torch.testing.assert_close(act0, act1)

The final assertion fails on main but succeeds here.
The runtime is about 2x higher with the sync on my machine (but the results are accurate).
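For context, the pattern being made robust here is roughly the following: a sketch of a non_blocking copy followed by an explicit device sync, written for a single tensor and not the actual TorchRL internals:

import torch

def to_cpu_safely(t: torch.Tensor) -> torch.Tensor:
    # Asynchronous device -> CPU copy: the call can return before the data
    # has actually landed in the destination tensor.
    out = t.to("cpu", non_blocking=True)
    # Without an explicit synchronization, reading `out` here can observe
    # stale values, which is the kind of mismatch the assertion above catches.
    if t.device.type == "cuda":
        torch.cuda.synchronize()
    elif t.device.type == "mps":
        torch.mps.synchronize()
    return out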


pytorch-bot bot commented Mar 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2034

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2325a90 with merge base 247ed6e:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 21, 2024
@vmoens marked this pull request as draft on March 21, 2024, 17:55
@vmoens added the bug (Something isn't working) label on Mar 21, 2024

github-actions bot commented Mar 21, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 91. Improved: $\large\color{#35bf28}6$. Worsened: $\large\color{#d91a1a}5$.

Expand to view detailed results
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
test_single 53.3812ms 52.2538ms 19.1374 Ops/s 18.0769 Ops/s $\textbf{\color{#35bf28}+5.87\%}$
test_sync 47.9155ms 32.2791ms 30.9798 Ops/s 30.1361 Ops/s $\color{#35bf28}+2.80\%$
test_async 62.0251ms 27.5733ms 36.2670 Ops/s 36.8335 Ops/s $\color{#d91a1a}-1.54\%$
test_simple 0.4056s 0.3469s 2.8828 Ops/s 2.8572 Ops/s $\color{#35bf28}+0.90\%$
test_transformed 0.5377s 0.4787s 2.0892 Ops/s 2.0005 Ops/s $\color{#35bf28}+4.43\%$
test_serial 1.2483s 1.1885s 0.8414 Ops/s 0.8267 Ops/s $\color{#35bf28}+1.78\%$
test_parallel 1.0951s 1.0241s 0.9765 Ops/s 0.9974 Ops/s $\color{#d91a1a}-2.09\%$
test_step_mdp_speed[True-True-True-True-True] 0.1779ms 20.9846μs 47.6539 KOps/s 47.5760 KOps/s $\color{#35bf28}+0.16\%$
test_step_mdp_speed[True-True-True-True-False] 42.5900μs 12.8893μs 77.5837 KOps/s 76.6016 KOps/s $\color{#35bf28}+1.28\%$
test_step_mdp_speed[True-True-True-False-True] 37.9200μs 12.4010μs 80.6385 KOps/s 80.5848 KOps/s $\color{#35bf28}+0.07\%$
test_step_mdp_speed[True-True-True-False-False] 56.9160μs 7.5341μs 132.7294 KOps/s 131.4300 KOps/s $\color{#35bf28}+0.99\%$
test_step_mdp_speed[True-True-False-True-True] 57.0570μs 22.2721μs 44.8992 KOps/s 44.5525 KOps/s $\color{#35bf28}+0.78\%$
test_step_mdp_speed[True-True-False-True-False] 49.9840μs 13.9849μs 71.5057 KOps/s 70.0538 KOps/s $\color{#35bf28}+2.07\%$
test_step_mdp_speed[True-True-False-False-True] 39.8350μs 13.5621μs 73.7348 KOps/s 73.0396 KOps/s $\color{#35bf28}+0.95\%$
test_step_mdp_speed[True-True-False-False-False] 43.2710μs 8.7474μs 114.3194 KOps/s 112.8041 KOps/s $\color{#35bf28}+1.34\%$
test_step_mdp_speed[True-False-True-True-True] 70.6820μs 23.5723μs 42.4226 KOps/s 42.0586 KOps/s $\color{#35bf28}+0.87\%$
test_step_mdp_speed[True-False-True-True-False] 74.6430μs 15.2511μs 65.5689 KOps/s 63.8310 KOps/s $\color{#35bf28}+2.72\%$
test_step_mdp_speed[True-False-True-False-True] 55.4640μs 13.5681μs 73.7025 KOps/s 73.6789 KOps/s $\color{#35bf28}+0.03\%$
test_step_mdp_speed[True-False-True-False-False] 35.0460μs 8.8790μs 112.6259 KOps/s 113.0591 KOps/s $\color{#d91a1a}-0.38\%$
test_step_mdp_speed[True-False-False-True-True] 61.7960μs 24.7389μs 40.4222 KOps/s 40.4233 KOps/s $-0.00\%$
test_step_mdp_speed[True-False-False-True-False] 52.7380μs 16.5347μs 60.4790 KOps/s 59.3943 KOps/s $\color{#35bf28}+1.83\%$
test_step_mdp_speed[True-False-False-False-True] 54.4820μs 14.7809μs 67.6550 KOps/s 68.1479 KOps/s $\color{#d91a1a}-0.72\%$
test_step_mdp_speed[True-False-False-False-False] 43.4420μs 9.9843μs 100.1574 KOps/s 99.2767 KOps/s $\color{#35bf28}+0.89\%$
test_step_mdp_speed[False-True-True-True-True] 60.5930μs 23.6723μs 42.2435 KOps/s 41.4740 KOps/s $\color{#35bf28}+1.86\%$
test_step_mdp_speed[False-True-True-True-False] 51.6860μs 15.3864μs 64.9926 KOps/s 63.6592 KOps/s $\color{#35bf28}+2.09\%$
test_step_mdp_speed[False-True-True-False-True] 45.7050μs 15.7093μs 63.6566 KOps/s 63.3692 KOps/s $\color{#35bf28}+0.45\%$
test_step_mdp_speed[False-True-True-False-False] 43.4010μs 9.9342μs 100.6622 KOps/s 100.1730 KOps/s $\color{#35bf28}+0.49\%$
test_step_mdp_speed[False-True-False-True-True] 48.5510μs 25.0053μs 39.9916 KOps/s 39.2117 KOps/s $\color{#35bf28}+1.99\%$
test_step_mdp_speed[False-True-False-True-False] 46.3870μs 16.6133μs 60.1926 KOps/s 59.0957 KOps/s $\color{#35bf28}+1.86\%$
test_step_mdp_speed[False-True-False-False-True] 49.6630μs 16.9011μs 59.1678 KOps/s 59.1020 KOps/s $\color{#35bf28}+0.11\%$
test_step_mdp_speed[False-True-False-False-False] 46.8880μs 11.1301μs 89.8461 KOps/s 89.5452 KOps/s $\color{#35bf28}+0.34\%$
test_step_mdp_speed[False-False-True-True-True] 63.3090μs 26.0027μs 38.4576 KOps/s 37.0821 KOps/s $\color{#35bf28}+3.71\%$
test_step_mdp_speed[False-False-True-True-False] 53.4600μs 17.7800μs 56.2431 KOps/s 53.3585 KOps/s $\textbf{\color{#35bf28}+5.41\%}$
test_step_mdp_speed[False-False-True-False-True] 67.7570μs 17.0405μs 58.6837 KOps/s 58.5954 KOps/s $\color{#35bf28}+0.15\%$
test_step_mdp_speed[False-False-True-False-False] 44.6640μs 11.1831μs 89.4204 KOps/s 88.9091 KOps/s $\color{#35bf28}+0.58\%$
test_step_mdp_speed[False-False-False-True-True] 69.3400μs 26.9193μs 37.1480 KOps/s 36.5243 KOps/s $\color{#35bf28}+1.71\%$
test_step_mdp_speed[False-False-False-True-False] 54.5930μs 18.7920μs 53.2143 KOps/s 52.2920 KOps/s $\color{#35bf28}+1.76\%$
test_step_mdp_speed[False-False-False-False-True] 46.2070μs 17.8902μs 55.8964 KOps/s 55.9680 KOps/s $\color{#d91a1a}-0.13\%$
test_step_mdp_speed[False-False-False-False-False] 45.8760μs 12.2180μs 81.8463 KOps/s 81.5936 KOps/s $\color{#35bf28}+0.31\%$
test_values[generalized_advantage_estimate-True-True] 12.3607ms 9.5962ms 104.2079 Ops/s 105.8152 Ops/s $\color{#d91a1a}-1.52\%$
test_values[vec_generalized_advantage_estimate-True-True] 39.2443ms 33.8291ms 29.5604 Ops/s 28.2883 Ops/s $\color{#35bf28}+4.50\%$
test_values[td0_return_estimate-False-False] 0.2455ms 0.1706ms 5.8618 KOps/s 5.7371 KOps/s $\color{#35bf28}+2.18\%$
test_values[td1_return_estimate-False-False] 26.4533ms 23.8980ms 41.8445 Ops/s 42.0094 Ops/s $\color{#d91a1a}-0.39\%$
test_values[vec_td1_return_estimate-False-False] 35.3396ms 33.5788ms 29.7807 Ops/s 28.2830 Ops/s $\textbf{\color{#35bf28}+5.30\%}$
test_values[td_lambda_return_estimate-True-False] 37.4173ms 34.1682ms 29.2670 Ops/s 29.8069 Ops/s $\color{#d91a1a}-1.81\%$
test_values[vec_td_lambda_return_estimate-True-False] 35.0090ms 33.5188ms 29.8340 Ops/s 28.2407 Ops/s $\textbf{\color{#35bf28}+5.64\%}$
test_gae_speed[generalized_advantage_estimate-False-1-512] 9.6570ms 8.3257ms 120.1103 Ops/s 121.2668 Ops/s $\color{#d91a1a}-0.95\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 2.3614ms 2.0689ms 483.3427 Ops/s 514.6278 Ops/s $\textbf{\color{#d91a1a}-6.08\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.4112ms 0.3533ms 2.8306 KOps/s 2.8742 KOps/s $\color{#d91a1a}-1.51\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 50.5891ms 47.7565ms 20.9395 Ops/s 23.9934 Ops/s $\textbf{\color{#d91a1a}-12.73\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 3.5888ms 3.0616ms 326.6311 Ops/s 329.5421 Ops/s $\color{#d91a1a}-0.88\%$
test_dqn_speed 7.8260ms 1.3699ms 729.9610 Ops/s 669.3932 Ops/s $\textbf{\color{#35bf28}+9.05\%}$
test_ddpg_speed 3.6461ms 2.7266ms 366.7597 Ops/s 367.1627 Ops/s $\color{#d91a1a}-0.11\%$
test_sac_speed 10.4166ms 8.4116ms 118.8829 Ops/s 121.2214 Ops/s $\color{#d91a1a}-1.93\%$
test_redq_speed 94.9047ms 14.3795ms 69.5433 Ops/s 75.5344 Ops/s $\textbf{\color{#d91a1a}-7.93\%}$
test_redq_deprec_speed 14.9376ms 13.7393ms 72.7841 Ops/s 73.3106 Ops/s $\color{#d91a1a}-0.72\%$
test_td3_speed 8.7984ms 8.3424ms 119.8693 Ops/s 120.9407 Ops/s $\color{#d91a1a}-0.89\%$
test_cql_speed 37.6596ms 36.1955ms 27.6278 Ops/s 27.4600 Ops/s $\color{#35bf28}+0.61\%$
test_a2c_speed 8.1571ms 7.3789ms 135.5220 Ops/s 134.6131 Ops/s $\color{#35bf28}+0.68\%$
test_ppo_speed 8.6615ms 7.7069ms 129.7543 Ops/s 128.5031 Ops/s $\color{#35bf28}+0.97\%$
test_reinforce_speed 7.2158ms 6.6299ms 150.8308 Ops/s 152.2161 Ops/s $\color{#d91a1a}-0.91\%$
test_iql_speed 34.5493ms 33.0329ms 30.2729 Ops/s 30.9841 Ops/s $\color{#d91a1a}-2.30\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.8972ms 2.2395ms 446.5236 Ops/s 444.9158 Ops/s $\color{#35bf28}+0.36\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.4089ms 0.4939ms 2.0247 KOps/s 2.0215 KOps/s $\color{#35bf28}+0.16\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.7395ms 0.4676ms 2.1387 KOps/s 2.1325 KOps/s $\color{#35bf28}+0.29\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.3491ms 2.2466ms 445.1155 Ops/s 428.7314 Ops/s $\color{#35bf28}+3.82\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.7539ms 0.4879ms 2.0496 KOps/s 1.9979 KOps/s $\color{#35bf28}+2.59\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3.6852ms 0.4682ms 2.1360 KOps/s 2.0642 KOps/s $\color{#35bf28}+3.48\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.3341ms 1.2439ms 803.8918 Ops/s 815.7358 Ops/s $\color{#d91a1a}-1.45\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 4.6724ms 1.1656ms 857.9515 Ops/s 868.8682 Ops/s $\color{#d91a1a}-1.26\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.5226ms 2.4457ms 408.8833 Ops/s 405.1121 Ops/s $\color{#35bf28}+0.93\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.8853ms 0.6190ms 1.6155 KOps/s 1.6165 KOps/s $\color{#d91a1a}-0.06\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8152ms 0.5851ms 1.7092 KOps/s 1.6744 KOps/s $\color{#35bf28}+2.08\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.3197ms 2.2710ms 440.3430 Ops/s 433.4284 Ops/s $\color{#35bf28}+1.60\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.8321ms 0.5008ms 1.9968 KOps/s 2.0101 KOps/s $\color{#d91a1a}-0.66\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.6422ms 0.4767ms 2.0979 KOps/s 2.1251 KOps/s $\color{#d91a1a}-1.28\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.4052ms 2.3545ms 424.7185 Ops/s 408.5122 Ops/s $\color{#35bf28}+3.97\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.0240ms 0.4972ms 2.0111 KOps/s 1.9995 KOps/s $\color{#35bf28}+0.58\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.7533ms 0.4750ms 2.1052 KOps/s 2.0758 KOps/s $\color{#35bf28}+1.42\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.5436ms 2.4355ms 410.5885 Ops/s 400.5543 Ops/s $\color{#35bf28}+2.51\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.9506ms 0.6238ms 1.6032 KOps/s 1.5984 KOps/s $\color{#35bf28}+0.30\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.7556ms 0.5944ms 1.6825 KOps/s 1.6648 KOps/s $\color{#35bf28}+1.06\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1193s 7.8829ms 126.8566 Ops/s 127.6588 Ops/s $\color{#d91a1a}-0.63\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 14.7851ms 12.2329ms 81.7470 Ops/s 82.5106 Ops/s $\color{#d91a1a}-0.93\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1.2416ms 1.0454ms 956.5410 Ops/s 947.9108 Ops/s $\color{#35bf28}+0.91\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1122s 5.8657ms 170.4839 Ops/s 171.4537 Ops/s $\color{#d91a1a}-0.57\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 15.2906ms 12.4909ms 80.0581 Ops/s 71.2459 Ops/s $\textbf{\color{#35bf28}+12.37\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 1.6396ms 1.1473ms 871.6401 Ops/s 951.7578 Ops/s $\textbf{\color{#d91a1a}-8.42\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1136s 8.3589ms 119.6327 Ops/s 166.1782 Ops/s $\textbf{\color{#d91a1a}-28.01\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 15.2635ms 12.7276ms 78.5697 Ops/s 81.2335 Ops/s $\color{#d91a1a}-3.28\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 1.9696ms 1.4017ms 713.4371 Ops/s 732.5273 Ops/s $\color{#d91a1a}-2.61\%$


github-actions bot commented Mar 21, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of GPU Benchmark Tests

Total Benchmarks: 94. Improved: $\large\color{#35bf28}10$. Worsened: $\large\color{#d91a1a}5$.

Expand to view detailed results
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
test_single 0.1014s 97.8544ms 10.2193 Ops/s 9.5566 Ops/s $\textbf{\color{#35bf28}+6.93\%}$
test_sync 86.0001ms 85.6991ms 11.6687 Ops/s 11.5169 Ops/s $\color{#35bf28}+1.32\%$
test_async 0.1594s 69.8134ms 14.3239 Ops/s 14.1488 Ops/s $\color{#35bf28}+1.24\%$
test_single_pixels 0.1063s 0.1059s 9.4391 Ops/s 9.2682 Ops/s $\color{#35bf28}+1.84\%$
test_sync_pixels 67.6936ms 65.7022ms 15.2202 Ops/s 15.2665 Ops/s $\color{#d91a1a}-0.30\%$
test_async_pixels 0.1232s 55.5644ms 17.9971 Ops/s 18.1080 Ops/s $\color{#d91a1a}-0.61\%$
test_simple 0.7305s 0.6634s 1.5074 Ops/s 1.4933 Ops/s $\color{#35bf28}+0.94\%$
test_transformed 0.9271s 0.8621s 1.1600 Ops/s 1.1277 Ops/s $\color{#35bf28}+2.86\%$
test_serial 2.0956s 2.0272s 0.4933 Ops/s 0.4803 Ops/s $\color{#35bf28}+2.70\%$
test_parallel 1.7995s 1.7449s 0.5731 Ops/s 0.5696 Ops/s $\color{#35bf28}+0.62\%$
test_step_mdp_speed[True-True-True-True-True] 99.7420μs 32.9167μs 30.3797 KOps/s 30.4911 KOps/s $\color{#d91a1a}-0.37\%$
test_step_mdp_speed[True-True-True-True-False] 36.7910μs 19.4978μs 51.2879 KOps/s 50.4947 KOps/s $\color{#35bf28}+1.57\%$
test_step_mdp_speed[True-True-True-False-True] 46.3100μs 18.6797μs 53.5341 KOps/s 54.2319 KOps/s $\color{#d91a1a}-1.29\%$
test_step_mdp_speed[True-True-True-False-False] 31.4100μs 11.0890μs 90.1795 KOps/s 89.4312 KOps/s $\color{#35bf28}+0.84\%$
test_step_mdp_speed[True-True-False-True-True] 55.2810μs 34.6099μs 28.8935 KOps/s 29.2397 KOps/s $\color{#d91a1a}-1.18\%$
test_step_mdp_speed[True-True-False-True-False] 45.0610μs 21.1649μs 47.2480 KOps/s 46.6955 KOps/s $\color{#35bf28}+1.18\%$
test_step_mdp_speed[True-True-False-False-True] 42.2810μs 20.4130μs 48.9884 KOps/s 49.3859 KOps/s $\color{#d91a1a}-0.80\%$
test_step_mdp_speed[True-True-False-False-False] 35.5300μs 12.8817μs 77.6295 KOps/s 75.9785 KOps/s $\color{#35bf28}+2.17\%$
test_step_mdp_speed[True-False-True-True-True] 63.8310μs 35.9247μs 27.8360 KOps/s 27.5604 KOps/s $\color{#35bf28}+1.00\%$
test_step_mdp_speed[True-False-True-True-False] 41.2800μs 23.3163μs 42.8885 KOps/s 42.3196 KOps/s $\color{#35bf28}+1.34\%$
test_step_mdp_speed[True-False-True-False-True] 38.3000μs 20.5465μs 48.6701 KOps/s 49.2037 KOps/s $\color{#d91a1a}-1.08\%$
test_step_mdp_speed[True-False-True-False-False] 33.7210μs 12.9322μs 77.3261 KOps/s 74.9765 KOps/s $\color{#35bf28}+3.13\%$
test_step_mdp_speed[True-False-False-True-True] 64.9610μs 38.2332μs 26.1553 KOps/s 26.1388 KOps/s $\color{#35bf28}+0.06\%$
test_step_mdp_speed[True-False-False-True-False] 54.6310μs 25.0452μs 39.9278 KOps/s 39.8992 KOps/s $\color{#35bf28}+0.07\%$
test_step_mdp_speed[True-False-False-False-True] 49.9910μs 22.2770μs 44.8893 KOps/s 44.3974 KOps/s $\color{#35bf28}+1.11\%$
test_step_mdp_speed[True-False-False-False-False] 33.7810μs 14.7767μs 67.6740 KOps/s 67.8884 KOps/s $\color{#d91a1a}-0.32\%$
test_step_mdp_speed[False-True-True-True-True] 54.1220μs 36.4078μs 27.4666 KOps/s 27.3672 KOps/s $\color{#35bf28}+0.36\%$
test_step_mdp_speed[False-True-True-True-False] 42.7510μs 23.3480μs 42.8302 KOps/s 41.5752 KOps/s $\color{#35bf28}+3.02\%$
test_step_mdp_speed[False-True-True-False-True] 52.4310μs 24.2670μs 41.2082 KOps/s 40.7793 KOps/s $\color{#35bf28}+1.05\%$
test_step_mdp_speed[False-True-True-False-False] 39.2310μs 14.8430μs 67.3718 KOps/s 67.6280 KOps/s $\color{#d91a1a}-0.38\%$
test_step_mdp_speed[False-True-False-True-True] 61.5010μs 38.5556μs 25.9366 KOps/s 26.0764 KOps/s $\color{#d91a1a}-0.54\%$
test_step_mdp_speed[False-True-False-True-False] 46.2710μs 25.3245μs 39.4874 KOps/s 39.0774 KOps/s $\color{#35bf28}+1.05\%$
test_step_mdp_speed[False-True-False-False-True] 46.1110μs 26.1657μs 38.2180 KOps/s 37.6717 KOps/s $\color{#35bf28}+1.45\%$
test_step_mdp_speed[False-True-False-False-False] 34.7300μs 16.7556μs 59.6817 KOps/s 60.0190 KOps/s $\color{#d91a1a}-0.56\%$
test_step_mdp_speed[False-False-True-True-True] 68.4110μs 40.1688μs 24.8949 KOps/s 25.2258 KOps/s $\color{#d91a1a}-1.31\%$
test_step_mdp_speed[False-False-True-True-False] 44.9910μs 27.1752μs 36.7983 KOps/s 37.2913 KOps/s $\color{#d91a1a}-1.32\%$
test_step_mdp_speed[False-False-True-False-True] 83.0020μs 26.3603μs 37.9358 KOps/s 37.8107 KOps/s $\color{#35bf28}+0.33\%$
test_step_mdp_speed[False-False-True-False-False] 42.8310μs 16.8230μs 59.4425 KOps/s 60.1464 KOps/s $\color{#d91a1a}-1.17\%$
test_step_mdp_speed[False-False-False-True-True] 61.7010μs 41.5335μs 24.0770 KOps/s 23.9607 KOps/s $\color{#35bf28}+0.49\%$
test_step_mdp_speed[False-False-False-True-False] 53.8810μs 28.6718μs 34.8775 KOps/s 34.2294 KOps/s $\color{#35bf28}+1.89\%$
test_step_mdp_speed[False-False-False-False-True] 47.7800μs 27.8960μs 35.8474 KOps/s 35.5418 KOps/s $\color{#35bf28}+0.86\%$
test_step_mdp_speed[False-False-False-False-False] 38.1210μs 18.5470μs 53.9170 KOps/s 54.2898 KOps/s $\color{#d91a1a}-0.69\%$
test_values[generalized_advantage_estimate-True-True] 24.0514ms 23.7075ms 42.1807 Ops/s 39.8152 Ops/s $\textbf{\color{#35bf28}+5.94\%}$
test_values[vec_generalized_advantage_estimate-True-True] 87.1013ms 3.3046ms 302.6097 Ops/s 313.7784 Ops/s $\color{#d91a1a}-3.56\%$
test_values[td0_return_estimate-False-False] 97.5320μs 63.7114μs 15.6958 KOps/s 15.3121 KOps/s $\color{#35bf28}+2.51\%$
test_values[td1_return_estimate-False-False] 53.3282ms 52.8591ms 18.9182 Ops/s 18.2509 Ops/s $\color{#35bf28}+3.66\%$
test_values[vec_td1_return_estimate-False-False] 2.1155ms 1.7569ms 569.1972 Ops/s 565.3292 Ops/s $\color{#35bf28}+0.68\%$
test_values[td_lambda_return_estimate-True-False] 86.0341ms 84.2941ms 11.8632 Ops/s 11.2388 Ops/s $\textbf{\color{#35bf28}+5.56\%}$
test_values[vec_td_lambda_return_estimate-True-False] 2.1186ms 1.7599ms 568.2181 Ops/s 565.3171 Ops/s $\color{#35bf28}+0.51\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 23.9720ms 23.7783ms 42.0551 Ops/s 39.7676 Ops/s $\textbf{\color{#35bf28}+5.75\%}$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 0.9166ms 0.6981ms 1.4324 KOps/s 1.3685 KOps/s $\color{#35bf28}+4.67\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.7198ms 0.6446ms 1.5514 KOps/s 1.4692 KOps/s $\textbf{\color{#35bf28}+5.60\%}$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.4931ms 1.4516ms 688.8765 Ops/s 681.7670 Ops/s $\color{#35bf28}+1.04\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 0.9518ms 0.6674ms 1.4983 KOps/s 1.4534 KOps/s $\color{#35bf28}+3.09\%$
test_dqn_speed 8.2818ms 1.4277ms 700.4214 Ops/s 690.3867 Ops/s $\color{#35bf28}+1.45\%$
test_ddpg_speed 3.4031ms 2.6670ms 374.9577 Ops/s 363.5583 Ops/s $\color{#35bf28}+3.14\%$
test_sac_speed 8.6774ms 7.9749ms 125.3933 Ops/s 122.1582 Ops/s $\color{#35bf28}+2.65\%$
test_redq_speed 10.8731ms 10.0068ms 99.9321 Ops/s 97.6933 Ops/s $\color{#35bf28}+2.29\%$
test_redq_deprec_speed 11.2814ms 10.6843ms 93.5953 Ops/s 88.6557 Ops/s $\textbf{\color{#35bf28}+5.57\%}$
test_td3_speed 8.5192ms 7.9164ms 126.3206 Ops/s 123.2286 Ops/s $\color{#35bf28}+2.51\%$
test_cql_speed 25.5961ms 24.7304ms 40.4360 Ops/s 36.4064 Ops/s $\textbf{\color{#35bf28}+11.07\%}$
test_a2c_speed 5.7827ms 5.5043ms 181.6757 Ops/s 177.9162 Ops/s $\color{#35bf28}+2.11\%$
test_ppo_speed 6.8785ms 5.8617ms 170.5994 Ops/s 167.5779 Ops/s $\color{#35bf28}+1.80\%$
test_reinforce_speed 5.2829ms 4.4497ms 224.7345 Ops/s 220.7949 Ops/s $\color{#35bf28}+1.78\%$
test_iql_speed 19.9920ms 19.3294ms 51.7347 Ops/s 51.2957 Ops/s $\color{#35bf28}+0.86\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.0109ms 2.8586ms 349.8204 Ops/s 344.3401 Ops/s $\color{#35bf28}+1.59\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.6297ms 0.5317ms 1.8807 KOps/s 1.8395 KOps/s $\color{#35bf28}+2.24\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.2383ms 0.5133ms 1.9483 KOps/s 1.9396 KOps/s $\color{#35bf28}+0.45\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.0620ms 2.8855ms 346.5606 Ops/s 341.0314 Ops/s $\color{#35bf28}+1.62\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6239ms 0.5211ms 1.9190 KOps/s 1.8634 KOps/s $\color{#35bf28}+2.98\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.1128s 0.6022ms 1.6604 KOps/s 1.9580 KOps/s $\textbf{\color{#d91a1a}-15.19\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 1.5094ms 1.4045ms 711.9955 Ops/s 683.6311 Ops/s $\color{#35bf28}+4.15\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 1.4845ms 1.3316ms 750.9801 Ops/s 719.1052 Ops/s $\color{#35bf28}+4.43\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.1987ms 2.9981ms 333.5391 Ops/s 331.7543 Ops/s $\color{#35bf28}+0.54\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.5450ms 0.6546ms 1.5277 KOps/s 1.4292 KOps/s $\textbf{\color{#35bf28}+6.89\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8553ms 0.6304ms 1.5864 KOps/s 1.5492 KOps/s $\color{#35bf28}+2.40\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.9615ms 2.8643ms 349.1195 Ops/s 342.4073 Ops/s $\color{#35bf28}+1.96\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.3285ms 0.5334ms 1.8748 KOps/s 1.8526 KOps/s $\color{#35bf28}+1.20\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.7129ms 0.5145ms 1.9438 KOps/s 1.9274 KOps/s $\color{#35bf28}+0.85\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.0832ms 2.9075ms 343.9372 Ops/s 342.7381 Ops/s $\color{#35bf28}+0.35\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6510ms 0.5273ms 1.8965 KOps/s 1.8657 KOps/s $\color{#35bf28}+1.65\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.4694ms 0.5076ms 1.9702 KOps/s 1.9559 KOps/s $\color{#35bf28}+0.73\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.1057ms 3.0055ms 332.7223 Ops/s 327.3884 Ops/s $\color{#35bf28}+1.63\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.6531ms 0.6566ms 1.5231 KOps/s 1.4853 KOps/s $\color{#35bf28}+2.54\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.7896ms 0.6328ms 1.5802 KOps/s 1.5270 KOps/s $\color{#35bf28}+3.48\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1341s 7.4488ms 134.2506 Ops/s 135.5788 Ops/s $\color{#d91a1a}-0.98\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 17.2711ms 14.8093ms 67.5253 Ops/s 64.7254 Ops/s $\color{#35bf28}+4.33\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 2.4231ms 1.1622ms 860.4394 Ops/s 923.4219 Ops/s $\textbf{\color{#d91a1a}-6.82\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1170s 7.0504ms 141.8363 Ops/s 142.1833 Ops/s $\color{#d91a1a}-0.24\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 16.8077ms 14.5546ms 68.7067 Ops/s 56.5662 Ops/s $\textbf{\color{#35bf28}+21.46\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 7.4339ms 1.2666ms 789.5193 Ops/s 938.2037 Ops/s $\textbf{\color{#d91a1a}-15.85\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1183s 9.6925ms 103.1722 Ops/s 133.9202 Ops/s $\textbf{\color{#d91a1a}-22.96\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 17.3496ms 14.9293ms 66.9822 Ops/s 63.5687 Ops/s $\textbf{\color{#35bf28}+5.37\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 7.2612ms 1.5622ms 640.1262 Ops/s 699.1795 Ops/s $\textbf{\color{#d91a1a}-8.45\%}$

@skandermoalla
Contributor

Happy to test when ready!

@vmoens
Contributor Author

vmoens commented Mar 22, 2024

@skandermoalla It should be ready for a test

Here is my script:

import torch.utils.benchmark

from torchrl.envs import GymEnv, ParallelEnv, SerialEnv

if __name__ == "__main__":
    for mp_start_method in [None, "spawn", "fork"]:
        for worker_device in ["cpu"]:
            for main_device in ["mps", "cpu"]:
                print(mp_start_method, worker_device, main_device)

                env = ParallelEnv(
                    16,
                    lambda: GymEnv("CartPole-v1", device=worker_device),
                    device=torch.device(main_device),
                    non_blocking=True,
                    mp_start_method=mp_start_method)

                env.rollout(2)
                print(torch.utils.benchmark.Timer("env.rollout(1000, break_when_any_done=False)",
                                                  globals=globals()).adaptive_autorange())

                torch.manual_seed(0)
                env.set_seed(0)
                rollout = env.rollout(1000, break_when_any_done=False)
                act0 = rollout["observation"].squeeze()

                torch.manual_seed(0)
                env.set_seed(0)
                rollout = env.rollout(1000, break_when_any_done=False)
                act1 = rollout["observation"].squeeze()

                torch.testing.assert_close(act0, act1)
                env.close()
                del env

I get a throughput of 11K fps on cpu-cpu-"fork" for Pendulum and 9.1K for CartPole (I get to 16K with 32 workers!).
About 2x slower with non_blocking on MPS, but that's expected.

One thing I'm considering is to provide a transform that maps only certain tensors to a given device: if your policy is on CUDA you may not need to move everything to CUDA, just the input keys. Since we now support partial devices (i.e. envs with no proper device), this could be a solution to minimize data movements.
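A rough sketch of that idea, assuming only the policy's input keys need to move (hypothetical helper, not an existing TorchRL transform):

import torch
from tensordict import TensorDict

def cast_selected_keys(td, keys, device, non_blocking=True):
    # Move only the listed entries to `device`; everything else stays where it is.
    for key in keys:
        td.set(key, td.get(key).to(device, non_blocking=non_blocking))
    return td

if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    td = TensorDict(
        {"observation": torch.randn(4, 3), "info": torch.zeros(4, 1)},
        batch_size=[4],
    )
    # Only "observation" (the policy input) is sent to the accelerator.
    cast_selected_keys(td, keys=["observation"], device=device)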

@skandermoalla
Contributor

@vmoens Are you using the Gymnasium 1.0 alpha? I get this with the latest stable Gymnasium; I think this behavior changed in 1.0.

torchrl ❯ python vmoens.py          
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
None cpu mps
Traceback (most recent call last):
  File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 11, in <module>
    env = ParallelEnv(
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 125, in __call__
    return super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
    instance: EnvBase = super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 335, in __init__
    self._get_metadata(create_env_fn, create_env_kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 449, in _get_metadata
    meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0])
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/env_creator.py", line 180, in get_env_metadata
    env = env_or_creator(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 13, in <lambda>
    lambda: GymEnv("CartPole-v1", device=worker_device),
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 543, in __call__
    instance: GymWrapper = super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
    instance: EnvBase = super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1277, in __init__
    super().__init__(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 731, in __init__
    super().__init__(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2930, in __init__
    self._init_env()  # runs all the steps to have a ready-to-use env
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1139, in _init_env
    self.reset()
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2070, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1165, in _reset
    return super()._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/gym_like.py", line 378, in _reset
    self._sync_device()
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2945, in _sync_device
    return self._sync_orig_device
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2967, in __getattr__
    return getattr(env, attr)
  File "/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py", line 310, in __getattr__
    raise AttributeError(f"accessing private attribute '{name}' is prohibited")
AttributeError: accessing private attribute '_sync_orig_device' is prohibited

~/projects/open-source/TorchRL/tests/mps main*
torchrl ❯ pip list | grep gymnasium
gymnasium                  0.29.1

@skandermoalla
Contributor

Still doesn't work with 1.0 actually.

torchrl ❯ python vmoens.py              
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
None cpu mps
Traceback (most recent call last):
  File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 11, in <module>
    env = ParallelEnv(
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 125, in __call__
    return super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
    instance: EnvBase = super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 335, in __init__
    self._get_metadata(create_env_fn, create_env_kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 449, in _get_metadata
    meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0])
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/env_creator.py", line 180, in get_env_metadata
    env = env_or_creator(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/tests/mps/vmoens.py", line 13, in <lambda>
    lambda: GymEnv("CartPole-v1", device=worker_device),
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 543, in __call__
    instance: GymWrapper = super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 160, in __call__
    instance: EnvBase = super().__call__(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1277, in __init__
    super().__init__(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 731, in __init__
    super().__init__(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2930, in __init__
    self._init_env()  # runs all the steps to have a ready-to-use env
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1139, in _init_env
    self.reset()
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2070, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/libs/gym.py", line 1165, in _reset
    return super()._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/gym_like.py", line 378, in _reset
    self._sync_device()
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2945, in _sync_device
    return self._sync_orig_device
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2967, in __getattr__
    return getattr(env, attr)
AttributeError: 'TimeLimit' object has no attribute '_sync_orig_device'

~/projects/open-source/TorchRL/tests/mps main*
torchrl ❯ pip list | grep gymnasium                                   
gymnasium                  1.0.0a1

@skandermoalla
Contributor

I guess _sync_orig_device is not meant to be forwarded to the Gymnasium environment 🤔.

Here is what I have:

torchrl                    0.4.0+404805a /Users/moalla/projects/open-source/TorchRL/repos/rl
tensordict                 0.4.0+2dc0285 /Users/moalla/projects/open-source/TorchRL/repos/tensordict
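On the forwarding point: the usual way to stop a delegating __getattr__ from passing private names to the wrapped object looks roughly like this (a sketch of the general pattern, not the actual TorchRL fix):

class DelegatingWrapper:
    """Minimal sketch: forward public attributes to a wrapped env, never private ones."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, attr):
        # A missing internal attribute (like `_sync_orig_device`) should raise on the
        # wrapper itself rather than being looked up on the wrapped env, where Gymnasium
        # either forbids private access (0.29) or raises a confusing AttributeError (1.0).
        if attr.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return getattr(self._env, attr)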

@skandermoalla
Contributor

skandermoalla commented Mar 23, 2024

Okay, something broke between commit aa72e598c13ef8cea2a99798bfacbcb372b559c7 and the latest one on this branch.

On aa72e598c13ef8cea2a99798bfacbcb372b559c7 (on an M2 MacBook Air) I have:

torchrl ❯ PYTHONWARNINGS="ignore" python vmoens.py
None cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x132852710>
env.rollout(1000, break_when_any_done=False)
  Median: 2.76 s
  IQR:    0.01 s (2.75 to 2.76)
  4 measurements, 1 runs per measurement, 1 thread
None cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x127b9c850>
env.rollout(1000, break_when_any_done=False)
  Median: 1.85 s
  IQR:    0.01 s (1.85 to 1.85)
  4 measurements, 1 runs per measurement, 1 thread
spawn cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x1409a1150>
env.rollout(1000, break_when_any_done=False)
  Median: 2.77 s
  IQR:    0.03 s (2.76 to 2.79)
  4 measurements, 1 runs per measurement, 1 thread
spawn cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x127bcf610>
env.rollout(1000, break_when_any_done=False)
  Median: 1.86 s
  IQR:    0.02 s (1.86 to 1.88)
  4 measurements, 1 runs per measurement, 1 thread
fork cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x1328531f0>
env.rollout(1000, break_when_any_done=False)
  Median: 2.79 s
  IQR:    0.01 s (2.78 to 2.79)
  4 measurements, 1 runs per measurement, 1 thread
fork cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x104d87eb0>
env.rollout(1000, break_when_any_done=False)
  Median: 1.94 s
  IQR:    0.03 s (1.92 to 1.95)
  4 measurements, 1 runs per measurement, 1 thread

@skandermoalla
Contributor

With aa72e598c13ef8cea2a99798bfacbcb372b559c7 I also get the following:

The previous issue in #1864 (comment) is solved.
There is no corruption of the observation anymore.

However, I found another bug (which may be coming from PyTorch).

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    ExplorationType,
    ParallelEnv,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
env_id = "CartPole-v1"
device = "mps"


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    env = ParallelEnv(4, lambda: GymEnv(env_id), device=device)
    # Changing to serial env removes the problem.
    env = TransformedEnv(env)
    # Or removing the transformed env removes the problem.

    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())
    for i in range(10):
        batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
torchrl ❯ python issue_transformed_env.py     
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  register_pytree_node(
/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/utils.py:166: UserWarning: The expected key set and actual key set differ. This will work but with a slower throughput than when the specs match exactly the actual key set in the data. Expected - Actual keys=set(), 
Actual - Expected keys={'logits'}.
  warnings.warn(
Traceback (most recent call last):
  File "/Users/moalla/projects/open-source/TorchRL/tests/mps/issue_transformed_env.py", line 40, in <module>
    batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2504, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2606, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2662, in step_and_maybe_reset
    tensordict_ = self.maybe_reset(tensordict_)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2707, in maybe_reset
    tensordict = self.reset(tensordict)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2069, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 798, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
  File "/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 60, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1486, in _reset
    self.shared_tensordicts[i].update_(
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 2859, in update_
    self._apply_nest(
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_td.py", line 761, in _apply_nest
    item_trsf = fn(unravel_key(prefix + (key,)), item, *_others)
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 2843, in inplace_update
    dest.copy_(source, non_blocking=non_blocking)
RuntimeError: destOffset % 4 == 0 INTERNAL ASSERT FAILED at "/Users/runner/work/_temp/anaconda/conda-bld/pytorch_1704987091277/work/aten/src/ATen/native/mps/operations/Copy.mm":107, please report a bug to PyTorch. Unaligned blit request

@vmoens
Contributor Author

vmoens commented Mar 23, 2024

  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1486, in _reset
    self.shared_tensordicts[i].update_(

Yeah, that was an error on my side: I pushed a commit with a typo. It should be fixed now.

@skandermoalla
Contributor

Cool. Works with the latest commit.

  1. On your script, I have:
None cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x14bda42e0>
env.rollout(1000, break_when_any_done=False)
  Median: 2.97 s
  IQR:    0.09 s (2.93 to 3.02)
  4 measurements, 1 runs per measurement, 1 thread
None cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x1509a51e0>
env.rollout(1000, break_when_any_done=False)
  Median: 1.93 s
  IQR:    0.02 s (1.92 to 1.95)
  4 measurements, 1 runs per measurement, 1 thread
spawn cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x13299c850>
env.rollout(1000, break_when_any_done=False)
  Median: 2.95 s
  IQR:    0.04 s (2.93 to 2.97)
  4 measurements, 1 runs per measurement, 1 thread
spawn cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x14bd47e80>
env.rollout(1000, break_when_any_done=False)
  Median: 1.99 s
  IQR:    0.06 s (1.98 to 2.05)
  4 measurements, 1 runs per measurement, 1 thread
fork cpu mps
<torch.utils.benchmark.utils.common.Measurement object at 0x150937460>
env.rollout(1000, break_when_any_done=False)
  Median: 3.19 s
  IQR:    0.05 s (3.16 to 3.21)
  4 measurements, 1 runs per measurement, 1 thread
fork cpu cpu
<torch.utils.benchmark.utils.common.Measurement object at 0x14ba7c190>
env.rollout(1000, break_when_any_done=False)
  Median: 2.25 s
  IQR:    0.07 s (2.21 to 2.28)
  4 measurements, 1 runs per measurement, 1 thread
  2. On the previous scripts that created the issue, everything seems fine; there is no issue anymore ([BUG] Problems with BatchedEnv on accelerated device with single envs on cpu #1864 (comment)).

  3. There is a bug on the script listed in [BugFix] Robust sync for non_blocking=True #2034 (comment).

Thanks for this! With 1 and 2 I guess the problem is solved. I'm not sure if 3 is in the scope of this PR.

@skandermoalla
Contributor

3 is still a barrier to using MPS in practice with TorchRL though, so it'd be good to address it soon too.

@vmoens
Contributor Author

vmoens commented Mar 25, 2024

3 is still a barrier to using MPS in practice with TorchRL though, so it'd be good to address it soon too.

Unfortunately this one is known and still not solved:
pytorch/pytorch#119367

Upvote, comment and make yourself heard on that issue!

torchrl/envs/batched_envs.py (outdated review comment, resolved)
torchrl/envs/batched_envs.py (outdated review comment, resolved)
Comment on lines +1474 to +1479
elif self.device is not None and self.device.type == "mps":
    # copy_ fails when moving mps->cpu using copy_
    # in some cases when a view of an mps tensor is used.
    # We know the shared tensors are not MPS, so we can
    # safely assume that the shared tensors are on cpu
    tensordict_ = tensordict_.to("cpu")
Contributor Author
@skandermoalla this will make the class usable, though it's not optimal.
This _reset will not be used much in practice, and the info copied will usually be lightweight (e.g. the "_reset" entries), so it should not slow things down too much.
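For illustration, a minimal repro sketch of the copy_ failure mode mentioned in the snippet above (hypothetical example; whether the direct copy trips the assert depends on the PyTorch build):

import torch

if torch.backends.mps.is_available():
    src = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4)
    view = src[:, 1:3]            # non-contiguous view with a storage offset
    dst = torch.empty(3, 2)       # CPU destination
    # dst.copy_(view)             # may hit the "Unaligned blit request" assert on some builds
    dst.copy_(view.to("cpu"))     # routing the source through .to("cpu") avoids the direct mps->cpu blit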

Contributor
Thanks for this!

@vmoens marked this pull request as ready for review on March 25, 2024, 15:03
@vmoens added the Environments (Adds or modifies an environment wrapper) label on Mar 26, 2024
Comment on lines +2897 to +2903
warnings.warn(
    "Your wrapper was not given a device. Currently, this "
    "value will default to 'cpu'. From v0.5 it will "
    "default to `None`. With a device of None, no device casting "
    "is performed and the resulting tensordicts are deviceless. "
    "Please set your device accordingly.",
    category=DeprecationWarning,
Contributor Author
Asking for feedback on this

@matteobettini @skandermoalla @albertbou92 @BY571 @alexander-soare @dennismalmgren

I know this is going to be massively annoying, but we need to deprecate the default 'cpu' device of envs.
First, it's arbitrary (an env can be on cuda by default), and second, it's restrictive (we now support envs that have part of their input/output on cpu and part on another device).

The only option I see to avoid having this warning pop up all the time is to add an environment variable that you can set to opt into the non-deprecated behaviour. I personally don't like relying on environment variables...

Wdyt?

Contributor
I'm not sure I understand the problem.

Here's my understanding:
We want to deprecate the default device of an env.
If users don't want to do casting, with the current version they can already just not set the device (or set it to None?). This is valid behavior.
However, this will raise a (useless) warning until v0.5 and we want to avoid that.
Nevertheless, we want the warning to be raised when the user is not aware of the deprecation and just didn't set a device.

This is a new feature, so users will be changing their code when adopting it anyway; an environment variable is not necessary, I'd say. They could set an optional keyword argument no_default_device=True until v0.5.
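A rough sketch of that suggestion (hypothetical no_default_device flag, not an existing TorchRL argument):

import warnings

def _resolve_env_device(device=None, no_default_device=False):
    # Hypothetical helper: keep the old 'cpu' default until v0.5, but let early
    # adopters opt out of the default device (and of the warning) explicitly.
    if device is not None:
        return device
    if no_default_device:
        return None  # deviceless env: no casting, resulting tensordicts have no device
    warnings.warn(
        "Your wrapper was not given a device; defaulting to 'cpu'. "
        "From v0.5 the default will be None (no device casting).",
        category=DeprecationWarning,
    )
    return "cpu"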

"Please set your device accordingly." was a bit confusing. On my first read I thought that meant I have to set a device explicitly and None is not a device. On my second read, I think it just means I have to be careful when picking between a device and None.

Contributor Author
However, this will raise a (useless) warning until v0.5 and we want to avoid that.
Nevertheless, we want the warning to be raised when the user is not aware of the deprecation and just didn't set a device.

Yes, so the only thing I want is that people who used to do assert env.device == "cpu" don't see their code break after 0.4, but know that it will break in v0.5.

So I take it that we should keep the warning then, right? Maybe I should use the torchrl logger so people can disable these warnings all at once.

Contributor
Yes, so the only thing I want is that people who used to do assert env.device == "cpu" don't see their code break after 0.4, but know that it will break in v0.5.

This is what the current code does. So yes, keep the warning.

Early adopters will change their code anyway, so it's fine to add an argument to suppress the warning IMO.

@vmoens merged commit 2b95b41 into main on Mar 26, 2024 (66 of 67 checks passed)
@vmoens deleted the fix-nonblocking-once-more branch on March 26, 2024, 19:21
vmoens added a commit that referenced this pull request Apr 7, 2024