-
Notifications
You must be signed in to change notification settings - Fork 325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Test] Add tests and a few fixes for ChessEnv #2661
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2661
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (10 Unrelated Failures)As of commit 8abf242 with merge base 91064bc (): FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 157cc72f2715bfd8a4c2dca048b0d58f63280d2f Pull Request resolved: #2661
ghstack-source-id: 157cc72f2715bfd8a4c2dca048b0d58f63280d2f Pull Request resolved: pytorch#2661
ghstack-source-id: 4be1cb3d54fdd43e71f6b0141a4179e0d49a111d Pull Request resolved: #2661
@@ -127,10 +127,13 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None): | |||
self._set_action_space(tensordict) | |||
return super().rand_action(tensordict) | |||
|
|||
def _is_done(self, board): | |||
return board.is_game_over() | board.is_fifty_moves() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a pretty important rule that we're neglecting at the moment, which is that if a particular position happens three times in a game, then it's a draw.
The chess
library has a feature to check for this, but it requires the chess.Board
to have the entire history of the game. Since we're constantly calling Board.set_fen
, we're erasing the history, so we don't have a way to detect repetition at the moment.
If the main purpose of ChessEnv is to prove MCTS works, then adding this rule may not be worth the effort. But if we specifically want to follow all the standard rules, then we'd have to figure out a way to add it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good point thanks for flagging it!
I'd be in favor of documenting it in the docstrings for now.
One option could be to use a state that is a sequence of all previous fen (like we do with LLMs where a state is all the prompt tokens + all generated tokens) and manually check for duplicates.
I see that chess also has a the option of exporting PGNs but it's a bit convoluted because we'd need to create a
import chess.pgn
# Create a new Board object
board = chess.Board()
# Make some moves
board.push_uci("e2e4")
board.push_uci("e7e5")
board.push_uci("d2d4")
# Create a new Game object
game = chess.pgn.Game()
# Add the moves to the game
node = game
for move in board.move_stack:
node = node.add_variation(move)
# Generate the PGN string
pgn_string = str(game)
print(pgn_string)
which gives us
[Event "?"]
[Site "?"]
[Date "????.??.??"]
[Round "?"]
[White "?"]
[Black "?"]
[Result "*"]
1. e4 e5 2. d4 *
So in practice, we could be using fen
or pgn
(or else?) as a state representation
env = ChessEnv(representation="fen")
I believe there should be a way to make these interchangeable (at the end of the day it's still string-in string-out)
ghstack-source-id: 4be1cb3d54fdd43e71f6b0141a4179e0d49a111d Pull Request resolved: pytorch#2661
ghstack-source-id: d2a63922664972465399fde369ecfa53ad13ce1c Pull Request resolved: #2661
ghstack-source-id: d2a63922664972465399fde369ecfa53ad13ce1c Pull Request resolved: pytorch#2661
ghstack-source-id: d2a63922664972465399fde369ecfa53ad13ce1c Pull Request resolved: #2661
ghstack-source-id: d2a63922664972465399fde369ecfa53ad13ce1c Pull Request resolved: pytorch#2661
ghstack-source-id: a2c74421a4c31193365157b1f53d2cf1010e9452 Pull Request resolved: #2661
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.4364s | 0.4328s | 2.3104 Ops/s | 2.1625 Ops/s | |
test_transformed | 0.6243s | 0.6106s | 1.6376 Ops/s | 1.5693 Ops/s | |
test_serial | 1.4739s | 1.3790s | 0.7252 Ops/s | 0.7263 Ops/s | |
test_parallel | 1.3976s | 1.3227s | 0.7560 Ops/s | 0.7472 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.4434ms | 30.6986μs | 32.5748 KOps/s | 32.1875 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 69.3090μs | 18.3253μs | 54.5694 KOps/s | 55.3764 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 44.0330μs | 17.5729μs | 56.9059 KOps/s | 57.1275 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 51.7170μs | 10.4629μs | 95.5755 KOps/s | 96.9026 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 95.5180μs | 32.9441μs | 30.3544 KOps/s | 30.3202 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 52.0570μs | 20.1360μs | 49.6623 KOps/s | 49.8369 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 61.0740μs | 19.2923μs | 51.8341 KOps/s | 51.3346 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 43.9720μs | 12.3428μs | 81.0187 KOps/s | 81.7705 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 69.2390μs | 35.0860μs | 28.5014 KOps/s | 28.5908 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 55.3830μs | 22.4383μs | 44.5666 KOps/s | 45.8429 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 85.7200μs | 19.3189μs | 51.7629 KOps/s | 51.9529 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 46.4860μs | 12.4067μs | 80.6017 KOps/s | 82.2691 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 74.0680μs | 36.7826μs | 27.1868 KOps/s | 26.8297 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 61.7850μs | 24.0588μs | 41.5648 KOps/s | 41.8224 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 64.9010μs | 21.1850μs | 47.2033 KOps/s | 46.8219 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 57.3180μs | 14.1099μs | 70.8722 KOps/s | 71.6974 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 74.7490μs | 34.8614μs | 28.6850 KOps/s | 28.4319 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 64.2090μs | 22.4813μs | 44.4814 KOps/s | 45.0476 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 65.9730μs | 22.1821μs | 45.0814 KOps/s | 45.0156 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 42.3090μs | 13.8230μs | 72.3432 KOps/s | 72.9899 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 75.7820μs | 36.6453μs | 27.2886 KOps/s | 27.2284 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 56.5360μs | 24.3671μs | 41.0390 KOps/s | 41.5298 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.9229ms | 24.2517μs | 41.2342 KOps/s | 41.5398 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 46.3270μs | 15.7680μs | 63.4194 KOps/s | 64.6437 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 71.3930μs | 38.9586μs | 25.6683 KOps/s | 25.9908 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 59.7910μs | 26.1029μs | 38.3099 KOps/s | 38.1920 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 57.5480μs | 23.9844μs | 41.6937 KOps/s | 41.7881 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 41.9680μs | 15.5725μs | 64.2157 KOps/s | 64.4791 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 78.9680μs | 40.3142μs | 24.8052 KOps/s | 24.7194 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 73.1970μs | 27.7048μs | 36.0949 KOps/s | 36.4004 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 73.9480μs | 25.4036μs | 39.3645 KOps/s | 39.2031 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 46.4160μs | 17.2915μs | 57.8319 KOps/s | 58.2133 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.0010ms | 9.6551ms | 103.5717 Ops/s | 103.3488 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 35.6475ms | 33.5991ms | 29.7627 Ops/s | 27.8848 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2179ms | 0.1778ms | 5.6244 KOps/s | 5.2226 KOps/s | |
test_values[td1_return_estimate-False-False] | 27.9134ms | 23.8874ms | 41.8630 Ops/s | 40.3439 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 35.7674ms | 33.7090ms | 29.6657 Ops/s | 27.4792 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 38.6580ms | 34.4257ms | 29.0481 Ops/s | 27.7775 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 35.9572ms | 33.7437ms | 29.6352 Ops/s | 27.8602 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.3420ms | 8.3495ms | 119.7679 Ops/s | 119.9945 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.2746ms | 2.0278ms | 493.1472 Ops/s | 501.9593 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6157ms | 0.3597ms | 2.7800 KOps/s | 2.7821 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 48.6672ms | 45.5077ms | 21.9743 Ops/s | 21.2588 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.0724ms | 3.0622ms | 326.5594 Ops/s | 327.4381 Ops/s | |
test_dqn_speed[False-None] | 5.7508ms | 1.4010ms | 713.7663 Ops/s | 719.1149 Ops/s | |
test_dqn_speed[False-backward] | 1.9484ms | 1.8535ms | 539.5205 Ops/s | 527.1339 Ops/s | |
test_dqn_speed[True-None] | 0.7256ms | 0.4587ms | 2.1801 KOps/s | 2.1493 KOps/s | |
test_dqn_speed[True-backward] | 0.9119ms | 0.8717ms | 1.1472 KOps/s | 739.7870 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7593ms | 0.4694ms | 2.1303 KOps/s | 2.1438 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0106ms | 0.9005ms | 1.1105 KOps/s | 1.1145 KOps/s | |
test_ddpg_speed[False-None] | 3.5694ms | 2.8727ms | 348.1008 Ops/s | 346.8930 Ops/s | |
test_ddpg_speed[False-backward] | 4.5263ms | 4.0293ms | 248.1829 Ops/s | 245.9799 Ops/s | |
test_ddpg_speed[True-None] | 1.2577ms | 0.9953ms | 1.0047 KOps/s | 996.9787 Ops/s | |
test_ddpg_speed[True-backward] | 2.0458ms | 1.9209ms | 520.5853 Ops/s | 525.3059 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.2787ms | 1.0083ms | 991.7600 Ops/s | 997.9820 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9639ms | 1.8900ms | 529.0952 Ops/s | 514.6894 Ops/s | |
test_sac_speed[False-None] | 10.0450ms | 8.1230ms | 123.1077 Ops/s | 121.9806 Ops/s | |
test_sac_speed[False-backward] | 12.6584ms | 11.0301ms | 90.6607 Ops/s | 88.0042 Ops/s | |
test_sac_speed[True-None] | 2.3562ms | 1.8225ms | 548.6971 Ops/s | 536.1286 Ops/s | |
test_sac_speed[True-backward] | 3.6925ms | 3.5354ms | 282.8569 Ops/s | 266.6449 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.3157ms | 1.8373ms | 544.2866 Ops/s | 540.2007 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.8358ms | 3.5922ms | 278.3792 Ops/s | 278.1524 Ops/s | |
test_redq_speed[False-None] | 14.9658ms | 13.1489ms | 76.0520 Ops/s | 66.1202 Ops/s | |
test_redq_speed[False-backward] | 24.0038ms | 22.5293ms | 44.3867 Ops/s | 43.4013 Ops/s | |
test_redq_speed[True-None] | 5.5611ms | 4.9129ms | 203.5462 Ops/s | 199.1081 Ops/s | |
test_redq_speed[True-backward] | 13.0394ms | 12.4099ms | 80.5810 Ops/s | 80.0405 Ops/s | |
test_redq_speed[reduce-overhead-None] | 6.0542ms | 4.8354ms | 206.8069 Ops/s | 212.9029 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 12.8065ms | 12.2760ms | 81.4599 Ops/s | 77.7997 Ops/s | |
test_redq_deprec_speed[False-None] | 14.6243ms | 13.1606ms | 75.9842 Ops/s | 74.0935 Ops/s | |
test_redq_deprec_speed[False-backward] | 20.5191ms | 19.4027ms | 51.5391 Ops/s | 50.7933 Ops/s | |
test_redq_deprec_speed[True-None] | 4.9656ms | 3.6715ms | 272.3663 Ops/s | 260.0350 Ops/s | |
test_redq_deprec_speed[True-backward] | 8.8491ms | 8.1009ms | 123.4432 Ops/s | 108.0115 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.1842ms | 3.6373ms | 274.9289 Ops/s | 266.1424 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 8.8563ms | 8.3470ms | 119.8036 Ops/s | 114.6569 Ops/s | |
test_td3_speed[False-None] | 8.6006ms | 8.1820ms | 122.2197 Ops/s | 120.4833 Ops/s | |
test_td3_speed[False-backward] | 11.3159ms | 10.7808ms | 92.7577 Ops/s | 93.2924 Ops/s | |
test_td3_speed[True-None] | 1.9480ms | 1.7285ms | 578.5453 Ops/s | 570.8091 Ops/s | |
test_td3_speed[True-backward] | 3.4412ms | 3.3150ms | 301.6573 Ops/s | 290.3016 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.0098ms | 1.7183ms | 581.9540 Ops/s | 565.8830 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.9133ms | 3.4425ms | 290.4861 Ops/s | 296.5032 Ops/s | |
test_cql_speed[False-None] | 41.7374ms | 36.7371ms | 27.2204 Ops/s | 26.5740 Ops/s | |
test_cql_speed[False-backward] | 48.5246ms | 46.3429ms | 21.5783 Ops/s | 20.7144 Ops/s | |
test_cql_speed[True-None] | 17.4439ms | 15.8973ms | 62.9038 Ops/s | 62.9375 Ops/s | |
test_cql_speed[True-backward] | 24.0704ms | 22.4800ms | 44.4840 Ops/s | 42.9276 Ops/s | |
test_cql_speed[reduce-overhead-None] | 17.1425ms | 15.6862ms | 63.7502 Ops/s | 61.4580 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 24.2857ms | 22.5350ms | 44.3755 Ops/s | 42.7414 Ops/s | |
test_a2c_speed[False-None] | 8.7240ms | 7.2801ms | 137.3601 Ops/s | 132.9376 Ops/s | |
test_a2c_speed[False-backward] | 16.2533ms | 14.6887ms | 68.0796 Ops/s | 66.7876 Ops/s | |
test_a2c_speed[True-None] | 5.0173ms | 4.2080ms | 237.6431 Ops/s | 221.8664 Ops/s | |
test_a2c_speed[True-backward] | 11.5344ms | 10.7156ms | 93.3222 Ops/s | 90.4524 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.7601ms | 4.2185ms | 237.0526 Ops/s | 230.0639 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 13.9312ms | 11.1245ms | 89.8915 Ops/s | 85.9541 Ops/s | |
test_ppo_speed[False-None] | 8.8717ms | 7.5901ms | 131.7508 Ops/s | 126.4333 Ops/s | |
test_ppo_speed[False-backward] | 16.1270ms | 15.0101ms | 66.6218 Ops/s | 62.0561 Ops/s | |
test_ppo_speed[True-None] | 4.1909ms | 3.7097ms | 269.5669 Ops/s | 260.9093 Ops/s | |
test_ppo_speed[True-backward] | 10.1309ms | 9.6060ms | 104.1019 Ops/s | 102.0013 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.0614ms | 3.7030ms | 270.0491 Ops/s | 266.2048 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.5519ms | 9.6502ms | 103.6244 Ops/s | 101.6805 Ops/s | |
test_reinforce_speed[False-None] | 8.6258ms | 6.6618ms | 150.1089 Ops/s | 148.6708 Ops/s | |
test_reinforce_speed[False-backward] | 10.2616ms | 9.8412ms | 101.6140 Ops/s | 99.5938 Ops/s | |
test_reinforce_speed[True-None] | 3.3999ms | 2.6598ms | 375.9742 Ops/s | 366.5452 Ops/s | |
test_reinforce_speed[True-backward] | 9.7821ms | 8.8080ms | 113.5334 Ops/s | 111.2635 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.3086ms | 2.6424ms | 378.4471 Ops/s | 361.7890 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 12.1617ms | 8.7286ms | 114.5663 Ops/s | 112.7065 Ops/s | |
test_iql_speed[False-None] | 34.3253ms | 32.6206ms | 30.6555 Ops/s | 30.0787 Ops/s | |
test_iql_speed[False-backward] | 47.6944ms | 45.4260ms | 22.0138 Ops/s | 21.5091 Ops/s | |
test_iql_speed[True-None] | 11.9180ms | 10.8960ms | 91.7770 Ops/s | 88.5038 Ops/s | |
test_iql_speed[True-backward] | 22.6103ms | 21.8625ms | 45.7404 Ops/s | 45.4540 Ops/s | |
test_iql_speed[reduce-overhead-None] | 12.0486ms | 10.7475ms | 93.0452 Ops/s | 90.6950 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 22.3013ms | 21.6224ms | 46.2484 Ops/s | 45.4466 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.4476ms | 4.9971ms | 200.1170 Ops/s | 192.5217 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.0135ms | 0.5170ms | 1.9342 KOps/s | 1.9295 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7940ms | 0.4918ms | 2.0332 KOps/s | 2.0112 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.7242ms | 4.8475ms | 206.2923 Ops/s | 208.0166 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.4412ms | 0.5037ms | 1.9852 KOps/s | 1.9568 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7211ms | 0.4791ms | 2.0872 KOps/s | 2.0620 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.8569ms | 1.6242ms | 615.6868 Ops/s | 601.1488 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.3170ms | 1.5895ms | 629.1306 Ops/s | 620.0229 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.3322ms | 4.9077ms | 203.7598 Ops/s | 195.8964 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.9937ms | 0.6515ms | 1.5350 KOps/s | 1.4976 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8446ms | 0.6210ms | 1.6102 KOps/s | 1.5730 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.2727ms | 4.8099ms | 207.9056 Ops/s | 200.0216 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9019ms | 0.5230ms | 1.9121 KOps/s | 1.8388 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6870ms | 0.4884ms | 2.0474 KOps/s | 1.9921 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.2587ms | 4.8254ms | 207.2378 Ops/s | 200.9485 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 3.1404ms | 0.5077ms | 1.9697 KOps/s | 1.9665 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6923ms | 0.4851ms | 2.0615 KOps/s | 2.0338 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.6563ms | 4.9005ms | 204.0620 Ops/s | 203.0697 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.4046ms | 0.6522ms | 1.5333 KOps/s | 1.5106 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8784ms | 0.6285ms | 1.5911 KOps/s | 1.5889 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.4409s | 13.0543ms | 76.6032 Ops/s | 37.3512 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 7.2183ms | 2.2899ms | 436.7013 Ops/s | 425.4940 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 5.6740ms | 1.2888ms | 775.9440 Ops/s | 718.3987 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 5.5813ms | 4.3237ms | 231.2818 Ops/s | 216.6238 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.3913s | 10.0514ms | 99.4886 Ops/s | 432.3913 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 4.3557ms | 1.2710ms | 786.7889 Ops/s | 649.5858 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 5.9564ms | 4.5980ms | 217.4882 Ops/s | 230.0512 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.2517ms | 2.4429ms | 409.3548 Ops/s | 400.6435 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.6390ms | 1.3606ms | 734.9763 Ops/s | 666.0348 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 12.7605ms | 11.2910ms | 88.5662 Ops/s | 83.3391 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 16.2688ms | 15.0364ms | 66.5052 Ops/s | 64.6946 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 21.7324ms | 19.9826ms | 50.0436 Ops/s | 48.5716 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.5968ms | 15.1227ms | 66.1259 Ops/s | 64.9141 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 21.3140ms | 20.0423ms | 49.8946 Ops/s | 48.5774 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.2809ms | 16.3102ms | 61.3112 Ops/s | 59.3299 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.7118s | 0.7097s | 1.4090 Ops/s | 1.3687 Ops/s | |
test_transformed | 0.9628s | 0.9622s | 1.0392 Ops/s | 1.0455 Ops/s | |
test_serial | 2.1970s | 2.1126s | 0.4734 Ops/s | 0.4810 Ops/s | |
test_parallel | 2.0222s | 1.9284s | 0.5186 Ops/s | 0.5033 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2023ms | 39.6569μs | 25.2163 KOps/s | 25.0189 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.3964ms | 23.4059μs | 42.7242 KOps/s | 42.6404 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 49.7600μs | 22.1223μs | 45.2033 KOps/s | 45.0760 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 47.4600μs | 12.9425μs | 77.2646 KOps/s | 77.0218 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.4228ms | 42.3073μs | 23.6366 KOps/s | 23.3713 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 58.6810μs | 25.6056μs | 39.0540 KOps/s | 38.8871 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.4050ms | 24.1444μs | 41.4174 KOps/s | 40.0850 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.3946ms | 15.3586μs | 65.1102 KOps/s | 65.3083 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 87.0910μs | 45.2848μs | 22.0825 KOps/s | 22.2381 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.4099ms | 28.3019μs | 35.3333 KOps/s | 35.2175 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.4048ms | 24.7834μs | 40.3495 KOps/s | 40.2928 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 45.1910μs | 15.2358μs | 65.6347 KOps/s | 65.0621 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.4255ms | 47.0620μs | 21.2485 KOps/s | 21.1400 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 0.4077ms | 30.2660μs | 33.0404 KOps/s | 32.7536 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.4081ms | 26.7300μs | 37.4111 KOps/s | 37.2814 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 0.3939ms | 17.4382μs | 57.3455 KOps/s | 56.1445 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 81.5310μs | 44.8349μs | 22.3040 KOps/s | 22.1996 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.4046ms | 27.9095μs | 35.8301 KOps/s | 35.8608 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 0.4149ms | 28.5560μs | 35.0189 KOps/s | 35.0160 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 50.6100μs | 17.1172μs | 58.4208 KOps/s | 58.1786 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.4306ms | 47.0246μs | 21.2655 KOps/s | 21.3035 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 67.2500μs | 30.3954μs | 32.8998 KOps/s | 33.2329 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.0426ms | 31.0819μs | 32.1731 KOps/s | 32.2247 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.4004ms | 19.3887μs | 51.5765 KOps/s | 51.3686 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 96.4610μs | 49.6079μs | 20.1581 KOps/s | 20.3234 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 0.4154ms | 32.7621μs | 30.5231 KOps/s | 30.2656 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 70.9410μs | 28.7329μs | 34.8033 KOps/s | 32.5387 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 46.6910μs | 19.1575μs | 52.1990 KOps/s | 51.2019 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 98.6910μs | 50.7617μs | 19.6999 KOps/s | 19.2576 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 64.4100μs | 34.6616μs | 28.8504 KOps/s | 30.0607 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 69.9810μs | 31.8533μs | 31.3939 KOps/s | 31.1329 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 48.9710μs | 21.2034μs | 47.1622 KOps/s | 46.3896 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 24.5693ms | 24.2088ms | 41.3072 Ops/s | 41.6637 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1024s | 2.9376ms | 340.4149 Ops/s | 321.7123 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1039ms | 79.0033μs | 12.6577 KOps/s | 12.7531 KOps/s | |
test_values[td1_return_estimate-False-False] | 54.6604ms | 53.9696ms | 18.5289 Ops/s | 18.5381 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3538ms | 1.0700ms | 934.5392 Ops/s | 936.5971 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 85.9052ms | 85.4230ms | 11.7064 Ops/s | 11.7425 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3815ms | 1.0658ms | 938.2703 Ops/s | 936.7071 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 24.2999ms | 24.0308ms | 41.6132 Ops/s | 42.2405 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0238ms | 0.7430ms | 1.3460 KOps/s | 1.3463 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7663ms | 0.6565ms | 1.5233 KOps/s | 1.5257 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.5367ms | 1.4692ms | 680.6355 Ops/s | 679.7602 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.8305ms | 0.6692ms | 1.4944 KOps/s | 1.4650 KOps/s | |
test_dqn_speed[False-None] | 6.7769ms | 1.5196ms | 658.0756 Ops/s | 670.1378 Ops/s | |
test_dqn_speed[False-backward] | 2.1546ms | 2.1021ms | 475.7093 Ops/s | 482.3031 Ops/s | |
test_dqn_speed[True-None] | 0.9245ms | 0.5355ms | 1.8673 KOps/s | 1.7952 KOps/s | |
test_dqn_speed[True-backward] | 1.2419ms | 1.1916ms | 839.1745 Ops/s | 848.9675 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.9739ms | 0.5750ms | 1.7391 KOps/s | 1.8036 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1041ms | 1.0608ms | 942.6492 Ops/s | 1.0296 KOps/s | |
test_ddpg_speed[False-None] | 3.1893ms | 2.8347ms | 352.7664 Ops/s | 353.5390 Ops/s | |
test_ddpg_speed[False-backward] | 4.4715ms | 4.1386ms | 241.6263 Ops/s | 249.7372 Ops/s | |
test_ddpg_speed[True-None] | 1.2795ms | 1.0775ms | 928.1072 Ops/s | 917.7667 Ops/s | |
test_ddpg_speed[True-backward] | 2.3238ms | 2.2846ms | 437.7152 Ops/s | 461.7143 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.2266ms | 1.0844ms | 922.2111 Ops/s | 888.9223 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.8235ms | 1.7585ms | 568.6792 Ops/s | 605.7765 Ops/s | |
test_sac_speed[False-None] | 8.4338ms | 8.0252ms | 124.6076 Ops/s | 125.2197 Ops/s | |
test_sac_speed[False-backward] | 11.5429ms | 11.0647ms | 90.3778 Ops/s | 92.3751 Ops/s | |
test_sac_speed[True-None] | 1.9716ms | 1.5387ms | 649.9043 Ops/s | 645.1472 Ops/s | |
test_sac_speed[True-backward] | 3.8633ms | 3.4323ms | 291.3459 Ops/s | 294.0736 Ops/s | |
test_sac_speed[reduce-overhead-None] | 23.0241ms | 12.5665ms | 79.5766 Ops/s | 79.9982 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.5788ms | 1.4845ms | 673.6279 Ops/s | 749.8924 Ops/s | |
test_redq_speed[False-None] | 8.7035ms | 7.5078ms | 133.1945 Ops/s | 132.9554 Ops/s | |
test_redq_speed[False-backward] | 12.3712ms | 11.5563ms | 86.5331 Ops/s | 89.0470 Ops/s | |
test_redq_speed[True-None] | 2.0377ms | 1.9839ms | 504.0459 Ops/s | 492.1801 Ops/s | |
test_redq_speed[True-backward] | 4.3549ms | 3.8355ms | 260.7201 Ops/s | 266.1378 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.1618ms | 2.0948ms | 477.3644 Ops/s | 487.2390 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.0955ms | 3.6419ms | 274.5803 Ops/s | 260.9628 Ops/s | |
test_redq_deprec_speed[False-None] | 9.9129ms | 9.1653ms | 109.1072 Ops/s | 110.2757 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.4213ms | 11.9291ms | 83.8284 Ops/s | 81.0516 Ops/s | |
test_redq_deprec_speed[True-None] | 2.3728ms | 2.3166ms | 431.6669 Ops/s | 425.7320 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.5541ms | 4.1492ms | 241.0127 Ops/s | 240.4184 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.3976ms | 2.3138ms | 432.1931 Ops/s | 408.5259 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.1042ms | 3.9802ms | 251.2427 Ops/s | 233.2979 Ops/s | |
test_td3_speed[False-None] | 34.7288ms | 8.4139ms | 118.8515 Ops/s | 126.8221 Ops/s | |
test_td3_speed[False-backward] | 10.8283ms | 10.3031ms | 97.0579 Ops/s | 95.7000 Ops/s | |
test_td3_speed[True-None] | 1.6289ms | 1.5966ms | 626.3356 Ops/s | 628.7796 Ops/s | |
test_td3_speed[True-backward] | 3.2571ms | 3.1590ms | 316.5527 Ops/s | 318.5017 Ops/s | |
test_td3_speed[reduce-overhead-None] | 50.3900ms | 25.7791ms | 38.7911 Ops/s | 36.7774 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.3337ms | 1.2677ms | 788.8217 Ops/s | 697.5447 Ops/s | |
test_cql_speed[False-None] | 17.2146ms | 16.6801ms | 59.9518 Ops/s | 58.9723 Ops/s | |
test_cql_speed[False-backward] | 22.2747ms | 21.6676ms | 46.1519 Ops/s | 45.0742 Ops/s | |
test_cql_speed[True-None] | 3.0554ms | 2.9552ms | 338.3888 Ops/s | 336.0626 Ops/s | |
test_cql_speed[True-backward] | 5.4987ms | 5.1008ms | 196.0487 Ops/s | 194.8729 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.5085ms | 13.2100ms | 75.7005 Ops/s | 75.8216 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 1.5711ms | 1.4920ms | 670.2358 Ops/s | 608.2093 Ops/s | |
test_a2c_speed[False-None] | 3.3417ms | 3.1601ms | 316.4484 Ops/s | 311.9196 Ops/s | |
test_a2c_speed[False-backward] | 6.5822ms | 5.9842ms | 167.1056 Ops/s | 157.1767 Ops/s | |
test_a2c_speed[True-None] | 1.0809ms | 1.0020ms | 998.0261 Ops/s | 978.3653 Ops/s | |
test_a2c_speed[True-backward] | 2.6359ms | 2.5830ms | 387.1513 Ops/s | 357.8377 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 21.3698ms | 11.6424ms | 85.8929 Ops/s | 86.4899 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.0252ms | 0.9691ms | 1.0319 KOps/s | 1.0151 KOps/s | |
test_ppo_speed[False-None] | 3.9688ms | 3.8102ms | 262.4539 Ops/s | 275.2663 Ops/s | |
test_ppo_speed[False-backward] | 7.6340ms | 6.8032ms | 146.9892 Ops/s | 149.3931 Ops/s | |
test_ppo_speed[True-None] | 1.0663ms | 0.9841ms | 1.0161 KOps/s | 1.0414 KOps/s | |
test_ppo_speed[True-backward] | 2.6306ms | 2.5673ms | 389.5202 Ops/s | 388.6112 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 0.5730ms | 0.5066ms | 1.9741 KOps/s | 1.9135 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.0143ms | 0.9472ms | 1.0557 KOps/s | 890.0178 Ops/s | |
test_reinforce_speed[False-None] | 2.4505ms | 2.2526ms | 443.9384 Ops/s | 449.6850 Ops/s | |
test_reinforce_speed[False-backward] | 3.6130ms | 3.2098ms | 311.5498 Ops/s | 302.4322 Ops/s | |
test_reinforce_speed[True-None] | 0.8911ms | 0.8218ms | 1.2168 KOps/s | 1.1805 KOps/s | |
test_reinforce_speed[True-backward] | 2.4682ms | 2.4013ms | 416.4475 Ops/s | 386.1304 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 22.0170ms | 11.6691ms | 85.6961 Ops/s | 88.5900 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.0564ms | 1.0213ms | 979.1218 Ops/s | 836.4456 Ops/s | |
test_iql_speed[False-None] | 9.7493ms | 9.2141ms | 108.5288 Ops/s | 109.2875 Ops/s | |
test_iql_speed[False-backward] | 13.5824ms | 12.7953ms | 78.1538 Ops/s | 76.6992 Ops/s | |
test_iql_speed[True-None] | 2.2395ms | 1.7716ms | 564.4539 Ops/s | 569.7311 Ops/s | |
test_iql_speed[True-backward] | 4.5151ms | 4.3994ms | 227.3012 Ops/s | 225.1074 Ops/s | |
test_iql_speed[reduce-overhead-None] | 20.5018ms | 11.5377ms | 86.6725 Ops/s | 87.9458 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 1.4464ms | 1.3984ms | 715.0824 Ops/s | 636.1179 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.8387ms | 6.4368ms | 155.3578 Ops/s | 152.7301 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.4781ms | 0.2690ms | 3.7173 KOps/s | 3.1270 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5881ms | 0.2474ms | 4.0420 KOps/s | 2.9817 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.4476ms | 6.1641ms | 162.2290 Ops/s | 160.6765 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.2376ms | 0.3022ms | 3.3093 KOps/s | 2.9009 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6481ms | 0.2982ms | 3.3536 KOps/s | 4.1932 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.4625ms | 1.2109ms | 825.8089 Ops/s | 759.1159 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.6045ms | 1.1705ms | 854.3081 Ops/s | 816.9909 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.5227ms | 6.3722ms | 156.9313 Ops/s | 157.1642 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.1185ms | 0.4407ms | 2.2693 KOps/s | 2.0364 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6860ms | 0.4110ms | 2.4333 KOps/s | 2.2736 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.2836ms | 6.1786ms | 161.8480 Ops/s | 163.0578 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.9804ms | 0.3577ms | 2.7957 KOps/s | 3.6346 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6184ms | 0.3397ms | 2.9434 KOps/s | 2.8232 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3967ms | 6.1732ms | 161.9903 Ops/s | 161.2698 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.5202ms | 0.2574ms | 3.8851 KOps/s | 3.1465 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4367ms | 0.2364ms | 4.2309 KOps/s | 4.2055 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4906ms | 6.3015ms | 158.6913 Ops/s | 157.6789 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.6155ms | 0.4995ms | 2.0019 KOps/s | 2.2363 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6916ms | 0.4755ms | 2.1031 KOps/s | 2.1288 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.2274ms | 5.3817ms | 185.8156 Ops/s | 186.2375 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.2199ms | 2.0706ms | 482.9588 Ops/s | 437.8107 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 6.8401ms | 1.2274ms | 814.7222 Ops/s | 843.6348 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.0021ms | 5.4346ms | 184.0070 Ops/s | 188.0095 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 10.3735ms | 2.0619ms | 484.9825 Ops/s | 428.5337 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 3.4237ms | 1.1473ms | 871.6244 Ops/s | 850.6718 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.5028s | 15.6050ms | 64.0819 Ops/s | 33.0593 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.1922ms | 2.1102ms | 473.8994 Ops/s | 547.6302 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 8.7146ms | 1.3731ms | 728.2687 Ops/s | 819.9522 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.5394ms | 13.3634ms | 74.8315 Ops/s | 73.8652 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.1877ms | 17.1316ms | 58.3715 Ops/s | 56.3741 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.2481ms | 17.7664ms | 56.2861 Ops/s | 54.7010 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.4761ms | 17.3741ms | 57.5571 Ops/s | 56.3965 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 17.5614ms | 17.3553ms | 57.6192 Ops/s | 54.4325 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.0309ms | 18.5890ms | 53.7952 Ops/s | 51.8166 Ops/s |
ghstack-source-id: a2c74421a4c31193365157b1f53d2cf1010e9452 Pull Request resolved: pytorch#2661
ghstack-source-id: d0fbb520e35c74305041340722a7560ac2f958f2 Pull Request resolved: #2661
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fantastic I love the shape it's taking
We should add a kwarg for representing the board as a tensor, I will add it later!
@@ -127,10 +127,13 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None): | |||
self._set_action_space(tensordict) | |||
return super().rand_action(tensordict) | |||
|
|||
def _is_done(self, board): | |||
return board.is_game_over() | board.is_fifty_moves() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good point thanks for flagging it!
I'd be in favor of documenting it in the docstrings for now.
One option could be to use a state that is a sequence of all previous fen (like we do with LLMs where a state is all the prompt tokens + all generated tokens) and manually check for duplicates.
I see that chess also has a the option of exporting PGNs but it's a bit convoluted because we'd need to create a
import chess.pgn
# Create a new Board object
board = chess.Board()
# Make some moves
board.push_uci("e2e4")
board.push_uci("e7e5")
board.push_uci("d2d4")
# Create a new Game object
game = chess.pgn.Game()
# Add the moves to the game
node = game
for move in board.move_stack:
node = node.add_variation(move)
# Generate the PGN string
pgn_string = str(game)
print(pgn_string)
which gives us
[Event "?"]
[Site "?"]
[Date "????.??.??"]
[Round "?"]
[White "?"]
[Black "?"]
[Result "*"]
1. e4 e5 2. d4 *
So in practice, we could be using fen
or pgn
(or else?) as a state representation
env = ChessEnv(representation="fen")
I believe there should be a way to make these interchangeable (at the end of the day it's still string-in string-out)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fantastic I love the shape it's taking
We should add a kwarg for representing the board as a tensor, I will add it later!
Stack from ghstack (oldest at bottom):