-
Notifications
You must be signed in to change notification settings - Fork 325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] ChessEnv #2641
[Feature] ChessEnv #2641
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2641
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures, 20 Unrelated FailuresAs of commit 68228f9 with merge base 4bc40a8 (): NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: b799f3a105ca47be61e8e6595a1ac30146dcc867 Pull Request resolved: #2641
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.4331s | 0.4316s | 2.3168 Ops/s | 2.2565 Ops/s | |
test_transformed | 0.7154s | 0.6282s | 1.5919 Ops/s | 1.6196 Ops/s | |
test_serial | 1.3445s | 1.3401s | 0.7462 Ops/s | 0.7331 Ops/s | |
test_parallel | 1.3261s | 1.2905s | 0.7749 Ops/s | 0.7575 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2265ms | 29.6251μs | 33.7551 KOps/s | 33.5258 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 52.6380μs | 17.7310μs | 56.3985 KOps/s | 57.0336 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 48.9910μs | 16.9819μs | 58.8861 KOps/s | 59.5265 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 29.5760μs | 10.0551μs | 99.4521 KOps/s | 101.1100 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 68.1670μs | 32.4919μs | 30.7769 KOps/s | 31.2891 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.6140ms | 19.6847μs | 50.8008 KOps/s | 50.4310 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 50.8450μs | 18.8053μs | 53.1766 KOps/s | 54.0776 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.1457ms | 12.6159μs | 79.2651 KOps/s | 85.7924 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 65.8530μs | 33.9483μs | 29.4566 KOps/s | 29.7346 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 48.2700μs | 21.2381μs | 47.0853 KOps/s | 46.7160 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 49.7530μs | 18.8686μs | 52.9982 KOps/s | 53.9706 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 40.6160μs | 11.7266μs | 85.2761 KOps/s | 86.2718 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.2451ms | 35.5667μs | 28.1162 KOps/s | 28.3807 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 0.2597ms | 23.2641μs | 42.9846 KOps/s | 42.9796 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 66.1970μs | 20.4840μs | 48.8187 KOps/s | 49.3813 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 44.5140μs | 13.4391μs | 74.4097 KOps/s | 74.6183 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 64.8210μs | 33.4371μs | 29.9069 KOps/s | 29.7311 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 57.8280μs | 21.4686μs | 46.5797 KOps/s | 46.5685 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 74.3490μs | 21.3999μs | 46.7291 KOps/s | 47.5245 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 55.7340μs | 13.1361μs | 76.1260 KOps/s | 77.2889 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 74.4590μs | 35.4761μs | 28.1880 KOps/s | 28.2193 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 52.9190μs | 23.0259μs | 43.4294 KOps/s | 43.5536 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 0.1021s | 26.6722μs | 37.4923 KOps/s | 44.3272 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.6077ms | 14.7146μs | 67.9599 KOps/s | 68.0383 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 86.8530μs | 37.4870μs | 26.6759 KOps/s | 27.0417 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 59.7710μs | 25.1599μs | 39.7457 KOps/s | 40.4132 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 64.8710μs | 23.0399μs | 43.4029 KOps/s | 44.6287 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 34.6350μs | 14.8891μs | 67.1631 KOps/s | 68.3239 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 95.4980μs | 38.4105μs | 26.0346 KOps/s | 26.0590 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 58.8100μs | 26.7117μs | 37.4368 KOps/s | 37.8759 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 54.0410μs | 24.4079μs | 40.9704 KOps/s | 41.6634 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 44.4730μs | 16.6904μs | 59.9148 KOps/s | 62.7936 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.1072ms | 9.7438ms | 102.6294 Ops/s | 104.8964 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 37.6337ms | 35.8964ms | 27.8579 Ops/s | 30.2109 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2426ms | 0.1817ms | 5.5038 KOps/s | 5.7233 KOps/s | |
test_values[td1_return_estimate-False-False] | 27.3063ms | 24.6585ms | 40.5540 Ops/s | 41.5463 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 38.0520ms | 35.9406ms | 27.8237 Ops/s | 30.0397 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 38.3141ms | 35.0315ms | 28.5457 Ops/s | 28.9413 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 38.0963ms | 36.0883ms | 27.7098 Ops/s | 30.0199 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.9012ms | 8.4622ms | 118.1720 Ops/s | 118.5250 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.9760ms | 1.7932ms | 557.6499 Ops/s | 496.3628 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5178ms | 0.3643ms | 2.7447 KOps/s | 2.7395 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 57.2308ms | 49.1492ms | 20.3462 Ops/s | 22.3374 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.0158ms | 3.0531ms | 327.5352 Ops/s | 327.6253 Ops/s | |
test_dqn_speed[False-None] | 2.1605ms | 1.3329ms | 750.2403 Ops/s | 724.6606 Ops/s | |
test_dqn_speed[False-backward] | 2.4740ms | 1.8547ms | 539.1778 Ops/s | 543.8927 Ops/s | |
test_dqn_speed[True-None] | 0.7625ms | 0.4655ms | 2.1482 KOps/s | 2.1537 KOps/s | |
test_dqn_speed[True-backward] | 0.9603ms | 0.8892ms | 1.1246 KOps/s | 1.1437 KOps/s | |
test_dqn_speed[reduce-overhead-None] | 0.6009ms | 0.4659ms | 2.1462 KOps/s | 2.1741 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9611ms | 0.8917ms | 1.1215 KOps/s | 1.1423 KOps/s | |
test_ddpg_speed[False-None] | 3.1082ms | 2.8001ms | 357.1259 Ops/s | 351.4708 Ops/s | |
test_ddpg_speed[False-backward] | 4.0237ms | 3.9275ms | 254.6150 Ops/s | 248.8246 Ops/s | |
test_ddpg_speed[True-None] | 1.1938ms | 0.9948ms | 1.0052 KOps/s | 1.0186 KOps/s | |
test_ddpg_speed[True-backward] | 1.9709ms | 1.8884ms | 529.5491 Ops/s | 499.6218 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4373ms | 0.9956ms | 1.0044 KOps/s | 1.0077 KOps/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.0215ms | 1.9077ms | 524.1852 Ops/s | 532.0466 Ops/s | |
test_sac_speed[False-None] | 9.5507ms | 7.9970ms | 125.0473 Ops/s | 125.5359 Ops/s | |
test_sac_speed[False-backward] | 11.3298ms | 10.6811ms | 93.6236 Ops/s | 93.5433 Ops/s | |
test_sac_speed[True-None] | 2.1176ms | 1.8040ms | 554.3277 Ops/s | 548.5009 Ops/s | |
test_sac_speed[True-backward] | 3.6056ms | 3.4673ms | 288.4083 Ops/s | 283.5129 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.4185ms | 1.7943ms | 557.3347 Ops/s | 546.4773 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.5095ms | 3.4435ms | 290.4037 Ops/s | 287.5994 Ops/s | |
test_redq_speed[False-None] | 19.0292ms | 12.8413ms | 77.8739 Ops/s | 75.6816 Ops/s | |
test_redq_speed[False-backward] | 24.7499ms | 22.0842ms | 45.2813 Ops/s | 44.0420 Ops/s | |
test_redq_speed[True-None] | 5.2640ms | 4.4624ms | 224.0936 Ops/s | 216.9611 Ops/s | |
test_redq_speed[True-backward] | 12.7451ms | 11.8588ms | 84.3258 Ops/s | 80.6038 Ops/s | |
test_redq_speed[reduce-overhead-None] | 5.5262ms | 4.4504ms | 224.7006 Ops/s | 200.0010 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 13.0062ms | 11.9319ms | 83.8087 Ops/s | 83.4657 Ops/s | |
test_redq_deprec_speed[False-None] | 13.9719ms | 12.4743ms | 80.1646 Ops/s | 78.7231 Ops/s | |
test_redq_deprec_speed[False-backward] | 19.0614ms | 18.0874ms | 55.2870 Ops/s | 54.0254 Ops/s | |
test_redq_deprec_speed[True-None] | 3.9982ms | 3.5237ms | 283.7892 Ops/s | 283.9776 Ops/s | |
test_redq_deprec_speed[True-backward] | 8.0375ms | 7.8250ms | 127.7960 Ops/s | 118.4172 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.1318ms | 3.5119ms | 284.7428 Ops/s | 282.9651 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 8.3834ms | 7.8536ms | 127.3309 Ops/s | 127.0386 Ops/s | |
test_td3_speed[False-None] | 8.8917ms | 7.7225ms | 129.4911 Ops/s | 125.3614 Ops/s | |
test_td3_speed[False-backward] | 12.2452ms | 10.1674ms | 98.3536 Ops/s | 96.8089 Ops/s | |
test_td3_speed[True-None] | 1.9638ms | 1.6794ms | 595.4357 Ops/s | 586.1076 Ops/s | |
test_td3_speed[True-backward] | 3.3186ms | 3.2588ms | 306.8645 Ops/s | 304.2732 Ops/s | |
test_td3_speed[reduce-overhead-None] | 1.9032ms | 1.6859ms | 593.1520 Ops/s | 579.5732 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.0353ms | 3.2890ms | 304.0410 Ops/s | 293.4435 Ops/s | |
test_cql_speed[False-None] | 43.8109ms | 36.4807ms | 27.4118 Ops/s | 27.2280 Ops/s | |
test_cql_speed[False-backward] | 63.0921ms | 47.3273ms | 21.1294 Ops/s | 20.8605 Ops/s | |
test_cql_speed[True-None] | 16.9356ms | 15.7852ms | 63.3506 Ops/s | 62.9501 Ops/s | |
test_cql_speed[True-backward] | 27.3871ms | 22.5769ms | 44.2931 Ops/s | 42.6461 Ops/s | |
test_cql_speed[reduce-overhead-None] | 16.2813ms | 15.4127ms | 64.8814 Ops/s | 63.2398 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 23.6191ms | 22.2748ms | 44.8937 Ops/s | 43.5954 Ops/s | |
test_a2c_speed[False-None] | 8.2147ms | 7.0854ms | 141.1363 Ops/s | 136.2875 Ops/s | |
test_a2c_speed[False-backward] | 15.2936ms | 14.1760ms | 70.5415 Ops/s | 67.0423 Ops/s | |
test_a2c_speed[True-None] | 4.5779ms | 4.1599ms | 240.3925 Ops/s | 233.5737 Ops/s | |
test_a2c_speed[True-backward] | 11.1258ms | 10.5885ms | 94.4419 Ops/s | 89.4684 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.7196ms | 4.1646ms | 240.1165 Ops/s | 232.8998 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 20.9431ms | 11.1378ms | 89.7842 Ops/s | 92.9100 Ops/s | |
test_ppo_speed[False-None] | 9.5040ms | 7.3328ms | 136.3729 Ops/s | 132.4478 Ops/s | |
test_ppo_speed[False-backward] | 15.7712ms | 14.6762ms | 68.1373 Ops/s | 64.1496 Ops/s | |
test_ppo_speed[True-None] | 4.0567ms | 3.6580ms | 273.3708 Ops/s | 268.8576 Ops/s | |
test_ppo_speed[True-backward] | 9.6900ms | 9.4452ms | 105.8738 Ops/s | 101.6377 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.3632ms | 3.6521ms | 273.8129 Ops/s | 252.0867 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.2452ms | 9.8711ms | 101.3057 Ops/s | 102.9900 Ops/s | |
test_reinforce_speed[False-None] | 8.5774ms | 6.5860ms | 151.8374 Ops/s | 148.8055 Ops/s | |
test_reinforce_speed[False-backward] | 11.2973ms | 9.9565ms | 100.4370 Ops/s | 99.0698 Ops/s | |
test_reinforce_speed[True-None] | 3.2725ms | 2.6672ms | 374.9252 Ops/s | 358.4032 Ops/s | |
test_reinforce_speed[True-backward] | 9.5596ms | 8.6768ms | 115.2497 Ops/s | 114.1728 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.2143ms | 2.6675ms | 374.8863 Ops/s | 374.4309 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 8.8436ms | 8.4743ms | 118.0037 Ops/s | 114.9224 Ops/s | |
test_iql_speed[False-None] | 32.9659ms | 31.6241ms | 31.6215 Ops/s | 30.6675 Ops/s | |
test_iql_speed[False-backward] | 46.5327ms | 44.6330ms | 22.4049 Ops/s | 21.9718 Ops/s | |
test_iql_speed[True-None] | 11.2228ms | 10.4853ms | 95.3717 Ops/s | 92.7012 Ops/s | |
test_iql_speed[True-backward] | 21.8846ms | 21.1549ms | 47.2704 Ops/s | 45.6349 Ops/s | |
test_iql_speed[reduce-overhead-None] | 11.5223ms | 10.6351ms | 94.0281 Ops/s | 92.4088 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 22.3998ms | 21.5015ms | 46.5083 Ops/s | 44.3457 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.4865ms | 4.9357ms | 202.6069 Ops/s | 194.1357 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.7043ms | 0.5035ms | 1.9861 KOps/s | 1.9117 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6683ms | 0.4803ms | 2.0821 KOps/s | 1.9931 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.1361ms | 4.6609ms | 214.5491 Ops/s | 193.6556 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.1700ms | 0.4854ms | 2.0601 KOps/s | 2.0064 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7545ms | 0.4668ms | 2.1422 KOps/s | 2.1258 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.1929ms | 1.6369ms | 610.8946 Ops/s | 606.4069 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.9324ms | 1.5843ms | 631.2024 Ops/s | 629.8557 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.1931ms | 4.9525ms | 201.9167 Ops/s | 196.3398 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.7340ms | 0.6346ms | 1.5757 KOps/s | 1.5282 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8164ms | 0.6082ms | 1.6443 KOps/s | 1.5848 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.1702ms | 4.6973ms | 212.8904 Ops/s | 201.0881 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8076ms | 0.5065ms | 1.9742 KOps/s | 1.8851 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 7.4796ms | 0.4898ms | 2.0416 KOps/s | 2.0360 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.2732ms | 4.6195ms | 216.4738 Ops/s | 205.2051 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.2248ms | 0.4874ms | 2.0517 KOps/s | 1.9736 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7095ms | 0.4701ms | 2.1272 KOps/s | 1.9937 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.2281ms | 4.8115ms | 207.8356 Ops/s | 199.0525 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.1806ms | 0.6348ms | 1.5754 KOps/s | 1.5476 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9147ms | 0.6164ms | 1.6223 KOps/s | 1.5549 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.4063s | 12.3608ms | 80.9007 Ops/s | 249.6540 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 9.4106ms | 2.3391ms | 427.5235 Ops/s | 437.3234 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.8897ms | 1.2532ms | 797.9744 Ops/s | 771.2217 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.7709ms | 4.5424ms | 220.1478 Ops/s | 34.1981 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.3629s | 9.4789ms | 105.4975 Ops/s | 431.0318 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 6.9913ms | 1.3562ms | 737.3793 Ops/s | 790.9071 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 6.2427ms | 4.4872ms | 222.8579 Ops/s | 218.5579 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 6.9113ms | 2.4402ms | 409.7960 Ops/s | 394.3260 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 6.3471ms | 1.5001ms | 666.6159 Ops/s | 688.6817 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 11.3318ms | 11.0307ms | 90.6558 Ops/s | 86.8859 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 15.0848ms | 14.3572ms | 69.6516 Ops/s | 68.9019 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 21.5252ms | 19.7764ms | 50.5654 Ops/s | 49.5320 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.5774ms | 14.5504ms | 68.7268 Ops/s | 68.5072 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 21.0664ms | 19.5808ms | 51.0704 Ops/s | 49.8835 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.5973ms | 15.7714ms | 63.4059 Ops/s | 63.7936 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.7435s | 0.7427s | 1.3465 Ops/s | 1.2701 Ops/s | |
test_transformed | 1.0945s | 1.0145s | 0.9857 Ops/s | 0.9904 Ops/s | |
test_serial | 2.2957s | 2.1995s | 0.4547 Ops/s | 0.4530 Ops/s | |
test_parallel | 2.0407s | 1.9529s | 0.5121 Ops/s | 0.5040 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2181ms | 37.3400μs | 26.7809 KOps/s | 23.9210 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 52.4610μs | 22.3677μs | 44.7073 KOps/s | 43.9335 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 50.2610μs | 20.7704μs | 48.1454 KOps/s | 44.4152 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 39.2710μs | 12.3778μs | 80.7898 KOps/s | 78.6446 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1153ms | 39.9458μs | 25.0339 KOps/s | 24.3665 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 60.0310μs | 24.1664μs | 41.3797 KOps/s | 40.5530 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 52.7410μs | 22.7550μs | 43.9464 KOps/s | 41.9349 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 49.7810μs | 14.4752μs | 69.0839 KOps/s | 67.5521 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 82.2710μs | 41.5727μs | 24.0543 KOps/s | 22.6080 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 53.5510μs | 26.2079μs | 38.1564 KOps/s | 37.2919 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 51.1710μs | 23.0532μs | 43.3779 KOps/s | 42.0412 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 43.9510μs | 14.5164μs | 68.8877 KOps/s | 68.1162 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 72.9610μs | 44.4584μs | 22.4929 KOps/s | 21.5652 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 64.1310μs | 28.7623μs | 34.7678 KOps/s | 34.5770 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 58.2900μs | 25.1142μs | 39.8181 KOps/s | 38.0828 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 39.9910μs | 16.3926μs | 61.0031 KOps/s | 58.8617 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 77.9010μs | 42.1677μs | 23.7148 KOps/s | 23.1854 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 54.9100μs | 26.0446μs | 38.3957 KOps/s | 37.1457 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 64.6910μs | 26.6804μs | 37.4807 KOps/s | 36.3446 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 47.8810μs | 15.7382μs | 63.5396 KOps/s | 60.6439 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 78.5710μs | 44.3138μs | 22.5663 KOps/s | 21.7180 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 60.3910μs | 28.1681μs | 35.5012 KOps/s | 34.6466 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.3148ms | 28.9118μs | 34.5880 KOps/s | 33.2766 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 47.0300μs | 18.2913μs | 54.6708 KOps/s | 53.0846 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 81.7720μs | 47.6577μs | 20.9830 KOps/s | 20.8915 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 61.5110μs | 30.3194μs | 32.9821 KOps/s | 31.7060 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 55.8210μs | 28.4501μs | 35.1493 KOps/s | 34.1200 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 46.6710μs | 18.2798μs | 54.7052 KOps/s | 54.6429 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 88.5510μs | 48.2823μs | 20.7115 KOps/s | 20.4416 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 63.4710μs | 32.6276μs | 30.6489 KOps/s | 30.6864 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 59.2310μs | 30.4674μs | 32.8219 KOps/s | 32.6304 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 44.4510μs | 19.8040μs | 50.4948 KOps/s | 49.3538 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 26.0980ms | 25.4790ms | 39.2480 Ops/s | 38.4557 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 97.5494ms | 2.8526ms | 350.5626 Ops/s | 327.8182 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1054ms | 80.7474μs | 12.3843 KOps/s | 12.3928 KOps/s | |
test_values[td1_return_estimate-False-False] | 56.4475ms | 55.7751ms | 17.9291 Ops/s | 17.5634 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3321ms | 1.0885ms | 918.6766 Ops/s | 918.4396 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 97.1070ms | 90.8157ms | 11.0113 Ops/s | 11.0949 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3240ms | 1.0858ms | 920.9403 Ops/s | 922.7844 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 25.4315ms | 25.3166ms | 39.4998 Ops/s | 40.0253 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0404ms | 0.7553ms | 1.3239 KOps/s | 1.3298 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7679ms | 0.6765ms | 1.4781 KOps/s | 1.4869 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.5552ms | 1.4851ms | 673.3605 Ops/s | 676.7186 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7234ms | 0.6886ms | 1.4523 KOps/s | 1.4596 KOps/s | |
test_dqn_speed[False-None] | 7.3661ms | 1.5431ms | 648.0486 Ops/s | 673.9410 Ops/s | |
test_dqn_speed[False-backward] | 2.1684ms | 2.1229ms | 471.0627 Ops/s | 468.0191 Ops/s | |
test_dqn_speed[True-None] | 0.6827ms | 0.5520ms | 1.8115 KOps/s | 1.8745 KOps/s | |
test_dqn_speed[True-backward] | 1.2991ms | 1.2312ms | 812.2245 Ops/s | 904.2082 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.6843ms | 0.5724ms | 1.7469 KOps/s | 1.7597 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.2370ms | 1.0990ms | 909.9056 Ops/s | 938.1396 Ops/s | |
test_ddpg_speed[False-None] | 3.2640ms | 2.9256ms | 341.8059 Ops/s | 351.5238 Ops/s | |
test_ddpg_speed[False-backward] | 4.8269ms | 4.3160ms | 231.6974 Ops/s | 236.1261 Ops/s | |
test_ddpg_speed[True-None] | 1.1595ms | 1.0689ms | 935.5799 Ops/s | 945.7651 Ops/s | |
test_ddpg_speed[True-backward] | 2.3636ms | 2.3045ms | 433.9263 Ops/s | 465.3393 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.1406ms | 1.0811ms | 924.9627 Ops/s | 921.0099 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.8510ms | 1.7909ms | 558.3904 Ops/s | 608.9942 Ops/s | |
test_sac_speed[False-None] | 8.5074ms | 8.0045ms | 124.9303 Ops/s | 123.9176 Ops/s | |
test_sac_speed[False-backward] | 12.1007ms | 11.3756ms | 87.9071 Ops/s | 89.2033 Ops/s | |
test_sac_speed[True-None] | 1.5988ms | 1.5303ms | 653.4574 Ops/s | 635.6520 Ops/s | |
test_sac_speed[True-backward] | 3.4544ms | 3.4054ms | 293.6547 Ops/s | 293.6847 Ops/s | |
test_sac_speed[reduce-overhead-None] | 22.0317ms | 12.3254ms | 81.1332 Ops/s | 80.3697 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.5617ms | 1.5010ms | 666.2265 Ops/s | 734.9282 Ops/s | |
test_redq_speed[False-None] | 8.3443ms | 7.4835ms | 133.6265 Ops/s | 132.7352 Ops/s | |
test_redq_speed[False-backward] | 12.9493ms | 11.8167ms | 84.6260 Ops/s | 87.1025 Ops/s | |
test_redq_speed[True-None] | 2.0340ms | 1.9792ms | 505.2466 Ops/s | 499.6782 Ops/s | |
test_redq_speed[True-backward] | 3.9176ms | 3.8343ms | 260.8007 Ops/s | 268.5052 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.0775ms | 1.9830ms | 504.2827 Ops/s | 502.9215 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.3544ms | 3.7015ms | 270.1607 Ops/s | 260.9664 Ops/s | |
test_redq_deprec_speed[False-None] | 9.5318ms | 9.0073ms | 111.0205 Ops/s | 109.6396 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.7283ms | 12.1782ms | 82.1141 Ops/s | 79.4958 Ops/s | |
test_redq_deprec_speed[True-None] | 2.3733ms | 2.3123ms | 432.4682 Ops/s | 432.1166 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.4517ms | 4.0087ms | 249.4595 Ops/s | 251.6187 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.4020ms | 2.3063ms | 433.5913 Ops/s | 433.3488 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.4141ms | 4.0434ms | 247.3181 Ops/s | 251.8098 Ops/s | |
test_td3_speed[False-None] | 8.2230ms | 7.8921ms | 126.7082 Ops/s | 126.5862 Ops/s | |
test_td3_speed[False-backward] | 10.8345ms | 10.3793ms | 96.3451 Ops/s | 96.9289 Ops/s | |
test_td3_speed[True-None] | 1.6056ms | 1.5572ms | 642.1599 Ops/s | 648.8883 Ops/s | |
test_td3_speed[True-backward] | 3.2008ms | 3.1128ms | 321.2528 Ops/s | 323.1797 Ops/s | |
test_td3_speed[reduce-overhead-None] | 47.9801ms | 24.5881ms | 40.6701 Ops/s | 38.5702 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.3423ms | 1.2875ms | 776.7252 Ops/s | 769.9516 Ops/s | |
test_cql_speed[False-None] | 16.5067ms | 16.1238ms | 62.0200 Ops/s | 61.6380 Ops/s | |
test_cql_speed[False-backward] | 21.9718ms | 21.5486ms | 46.4067 Ops/s | 46.0686 Ops/s | |
test_cql_speed[True-None] | 3.2077ms | 3.0964ms | 322.9534 Ops/s | 342.9842 Ops/s | |
test_cql_speed[True-backward] | 5.4429ms | 5.0839ms | 196.6977 Ops/s | 188.0948 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.4546ms | 12.9470ms | 77.2383 Ops/s | 77.6337 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 1.5426ms | 1.5039ms | 664.9167 Ops/s | 596.6732 Ops/s | |
test_a2c_speed[False-None] | 3.2774ms | 3.1768ms | 314.7844 Ops/s | 310.4098 Ops/s | |
test_a2c_speed[False-backward] | 6.8202ms | 6.2697ms | 159.4981 Ops/s | 154.9539 Ops/s | |
test_a2c_speed[True-None] | 1.0629ms | 0.9947ms | 1.0053 KOps/s | 974.0225 Ops/s | |
test_a2c_speed[True-backward] | 2.6802ms | 2.6251ms | 380.9362 Ops/s | 354.1405 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 0.3972s | 11.9943ms | 83.3728 Ops/s | 89.2122 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.0611ms | 1.0037ms | 996.3590 Ops/s | 871.6926 Ops/s | |
test_ppo_speed[False-None] | 3.7558ms | 3.6783ms | 271.8624 Ops/s | 269.2614 Ops/s | |
test_ppo_speed[False-backward] | 7.4766ms | 6.9277ms | 144.3481 Ops/s | 138.1136 Ops/s | |
test_ppo_speed[True-None] | 1.0003ms | 0.9339ms | 1.0708 KOps/s | 1.0376 KOps/s | |
test_ppo_speed[True-backward] | 2.6552ms | 2.5674ms | 389.5010 Ops/s | 370.6218 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 0.6383ms | 0.4833ms | 2.0693 KOps/s | 1.9591 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.1675ms | 1.1297ms | 885.2064 Ops/s | 877.8149 Ops/s | |
test_reinforce_speed[False-None] | 2.3220ms | 2.2363ms | 447.1717 Ops/s | 440.2945 Ops/s | |
test_reinforce_speed[False-backward] | 3.8554ms | 3.4368ms | 290.9715 Ops/s | 293.2156 Ops/s | |
test_reinforce_speed[True-None] | 0.9696ms | 0.8366ms | 1.1953 KOps/s | 1.2164 KOps/s | |
test_reinforce_speed[True-backward] | 2.6564ms | 2.5762ms | 388.1622 Ops/s | 391.3644 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 21.8660ms | 11.5672ms | 86.4512 Ops/s | 90.3418 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.2747ms | 1.2205ms | 819.3393 Ops/s | 940.2065 Ops/s | |
test_iql_speed[False-None] | 10.1620ms | 9.2359ms | 108.2737 Ops/s | 109.0814 Ops/s | |
test_iql_speed[False-backward] | 13.8166ms | 13.3410ms | 74.9571 Ops/s | 76.7808 Ops/s | |
test_iql_speed[True-None] | 1.8454ms | 1.7525ms | 570.5981 Ops/s | 559.4600 Ops/s | |
test_iql_speed[True-backward] | 4.6369ms | 4.4876ms | 222.8347 Ops/s | 235.7223 Ops/s | |
test_iql_speed[reduce-overhead-None] | 20.2050ms | 11.2548ms | 88.8510 Ops/s | 113.7153 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 1.5868ms | 1.5498ms | 645.2280 Ops/s | 703.0247 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.8932ms | 6.2978ms | 158.7850 Ops/s | 157.7003 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5997ms | 0.3499ms | 2.8577 KOps/s | 2.8888 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5734ms | 0.3649ms | 2.7402 KOps/s | 3.2080 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3027ms | 6.0420ms | 165.5074 Ops/s | 164.4297 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.1005ms | 0.2840ms | 3.5212 KOps/s | 3.4460 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6448ms | 0.3032ms | 3.2983 KOps/s | 3.3910 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.4795ms | 1.2762ms | 783.5752 Ops/s | 762.2931 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4858ms | 1.2196ms | 819.9466 Ops/s | 796.0576 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4923ms | 6.2428ms | 160.1850 Ops/s | 160.0910 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9872ms | 0.4089ms | 2.4453 KOps/s | 2.0561 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7916ms | 0.4624ms | 2.1624 KOps/s | 2.1460 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.2659ms | 6.0499ms | 165.2916 Ops/s | 165.2761 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.7796ms | 0.2781ms | 3.5957 KOps/s | 2.8343 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5001ms | 0.3278ms | 3.0508 KOps/s | 2.6736 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3969ms | 6.0402ms | 165.5573 Ops/s | 166.0985 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.5228ms | 0.3200ms | 3.1252 KOps/s | 2.8711 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5538ms | 0.2546ms | 3.9280 KOps/s | 3.2873 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3901ms | 6.2801ms | 159.2328 Ops/s | 161.2588 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3763ms | 0.4095ms | 2.4419 KOps/s | 2.2528 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.5225ms | 0.3875ms | 2.5805 KOps/s | 2.3449 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.9371ms | 5.2977ms | 188.7608 Ops/s | 190.7916 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 3.9935ms | 1.9360ms | 516.5182 Ops/s | 422.8916 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 8.7585ms | 1.2605ms | 793.3546 Ops/s | 756.7494 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.4975s | 15.2120ms | 65.7374 Ops/s | 190.3049 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 10.0420ms | 2.0024ms | 499.4104 Ops/s | 425.1538 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.7025ms | 1.2379ms | 807.8456 Ops/s | 845.4744 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 7.5654ms | 5.6223ms | 177.8635 Ops/s | 32.7469 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.3888ms | 2.2258ms | 449.2760 Ops/s | 434.1424 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.3920ms | 1.4305ms | 699.0450 Ops/s | 686.0384 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.6238ms | 13.0423ms | 76.6736 Ops/s | 75.1243 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.0923ms | 16.9742ms | 58.9128 Ops/s | 60.3264 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.0146ms | 17.5687ms | 56.9193 Ops/s | 55.4169 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 17.4745ms | 16.7330ms | 59.7623 Ops/s | 59.6286 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 18.1418ms | 17.5951ms | 56.8341 Ops/s | 56.8911 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 18.6891ms | 17.9087ms | 55.8387 Ops/s | 54.8903 Ops/s |
|
||
done = board.is_checkmate() | ||
turn = torch.tensor(board.turn) | ||
reward = torch.tensor([done]).int() * (turn.int() * 2 - 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a nit: I think this line is a little hard to read. Maybe something like this is a little easier to grasp?
winner = not board.turn
reward = torch.tensor([0 if not done else (-1 if winner == chess.BLACK else 1)])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since there are multiple ways of ending the game maybe we should be a bit more comprehensive there.:
Win = 1
Lose = -1
Stalemate ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's worth considering that usually in tournaments, if there's a checkmate or resignation, the winner gets 1 point and the loser gets 0, but if it's a draw of any kind, both players get 0.5 points. source
But I think whatever the values of win/lose are, the value of draw should probably be the average of win and lose
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also chess evaluation engines, like stockfish, use 0 if the position is equal (including draws), negative score for a black advantage, and positive score for white advantage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok so what about 0 / 0.5 / 1?
Also worth considering: more granular reward
https://www.restack.io/p/reinforcement-learning-answer-chess-bot-cat-ai
Since we have the opportunity of having multiple rewards, we could add another tensor that assign a reward for taking / losing pieces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that -1/0/1 is probably more consistent with existing chess engines, and it can encapsulate the value of the game for both black and white in a single number.
0/0.5/1 is probably also good, but then I guess we'd need two rewards, one for black and one for white. That is, unless we always set the reward to 0.5 for non-terminal states, but I'm not sure that would be ideal for transforms.RewardSum
.
a reward for taking / losing pieces
I think that could promote moves that do not optimize actually winning the game. For instance, if the reward decreases when losing a piece, then it could discourage making a sacrifice (or a combination of sacrifices) that leads to a forced checkmate. Likewise, if the reward increases when taking pieces, it could potentially encourage taking a piece rather than properly defending against an imminent checkmate threat.
Although it depends on what you want. Sometimes you don't want a chess bot to play optimally--like if it is meant to play against humans who are no match for the best chess bots
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with all of this. With the chess
lib, it's always white-centric no? ie, win
means white wins, or is there a way to flip things around and be black-centric?
ghstack-source-id: 42d198f63f23b5ebdff5d30870c2e4210692a1b0 Pull Request resolved: #2641
else: | ||
self.board.set_fen(fen.data) | ||
|
||
hashing = hash(fen) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I understand correctly, hash
is not guaranteed to give unique values for string inputs, since the number of possible strings is infinite, greater than the number of hash values. The output of hash
only has 2^64 possible values.
>>> import sys
>>> sys.hash_info.width
64
Is uniqueness required for this hash?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the questions to ask are
- what are the chances that a collision will occur
- what happens if a collision occurs
If we build a forest with - say - 10M elements, the number of combinations is still 10^13 times bigger than the capacity of the forest so I think it's safe to assume that the risk of failures to rebuild a tree due to hash collision is going to be small.
Another option on the safe size could also be to tokenize the fen. We could pad the tokens if they're shorter to make sure they all fit contiguously in memory.
Another question to solve is the one of reproducibility which worries me more than collision: if you restart your python process, the hash map will not hold anymore so any save data will be meaningless. IIRC there's a way to set the "seed" of the hash but that'd acting on a global variable which we may want to avoid anyway!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, at 10 million, the risk of collision is insignificant. I was curious where the limit is, so I looked into it. If my reasoning is correct, the risk starts to become significant around the order of 1 billion generated hashes. I think at 1 billion, the probability of collision is 2.7%. At 5 billion, the probability is almost 50%. (I put up some notes here).
So if we expect the tree to have significantly fewer than 1 billion nodes, then Python hash
should be good enough I suppose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like one option for reproducibility is to use hashlib
, which I think also has features for generating unique hashes if we decide we need that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm open to it!
ghstack-source-id: b40e0cfab2a3955ae5c44dc0963a31f127c17040 Pull Request resolved: #2641
ghstack-source-id: ba6f6a0b3f1b450ee93df816c55e0efc921d6f0d Pull Request resolved: #2641
@kurtamohler I landed all the stack up until here. For the record here is the random forest maker I showed the other day
|
Ok I'll give it a shot! It might be easiest if we merge this as-is first and then I can make a new PR with updates. I looked into whether ghstack has collaboration features, and it looks like it doesn't |
ghstack-source-id: 087c3b12cd621ea11a252b34c4896133697bce1a Pull Request resolved: #2641
ghstack-source-id: 087c3b12cd621ea11a252b34c4896133697bce1a Pull Request resolved: #2641
Done! Thanks for that |
Stack from ghstack (oldest at bottom):