diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 775dcfe206d..f8c18147306 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 0276039058f..d115174eb9c 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval final = collected_frames >= collector.total_frames diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index d8185c8091c..5ca70f83b53 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 5f8f81357c8..cf629ed0733 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // evaluation_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index 4b6f14cd058..d0d6693eb97 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index eb0b88c26f7..a92ee6185c3 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index dcb074b77fe..9cca9fd8af5 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -116,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 to_log = {"train/loss": loss_vals["loss"]} # Evaluation - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): if i % pretrain_log_interval == 0: eval_td = test_env.rollout( max_steps=eval_steps, diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 5cb297e5c0b..da2241ce9fa 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } # Evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 6e100f92dc3..386f743c7d3 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // eval_iter final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 90f93551d4d..906273ee2f5 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index ac3f17a9203..173f88f7028 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e7b346b2b22..af8d3334950 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -284,7 +284,7 @@ def compile_rssms(module): # Evaluation if (i % eval_iter) == 0: # Real env - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_rollout = test_env.rollout( eval_rollout_steps, policy, @@ -298,7 +298,7 @@ def compile_rssms(module): log_metrics(logger, eval_metrics, collected_frames) # Simulated env if model_based_env_eval is not None: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index ff14871b011..59a17ff8648 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -535,7 +535,7 @@ def _dreamer_make_actor_real( SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=[action_key], - default_interaction_type=InteractionType.MODE, + default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, spec=CompositeSpec( diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index 0482a595ffa..1998c044305 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -247,7 +247,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index ce96cf06ce8..fdee4256c42 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -239,7 +239,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index bb0f314197a..cf583909620 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 33513dd3973..ae1894379fd 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index d98724e1371..d1a16fd8192 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -130,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index b66c6f9dcf2..d50ff806294 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index 81551ebefb7..a4d2b88a9d0 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index 9d14ff04b04..bd44bb0a043 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index e752c4d73f2..fa006a7d4a2 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -236,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index d294a9c783e..4e6a962c556 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index 30b7e7e98bc..f7b2523010b 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 908cb7924a3..2b02254032a 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index e3e74971a49..219ae1b59b6 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -210,7 +210,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index f7a399cda72..9904fe072ab 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 97fd039c238..5fbc9b032d7 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps,