From 5bb6abdb49fabab5ba201ee7abb416b2fad4d1ac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 13:07:43 -0800 Subject: [PATCH] [Quality] IMPALA auto-device ghstack-source-id: 202e8a48b78cc03277f928cbab696d26253bc0ee Pull Request resolved: https://github.com/pytorch/rl/pull/2654 --- sota-implementations/impala/config_multi_node_ray.yaml | 2 +- .../impala/config_multi_node_submitit.yaml | 2 +- sota-implementations/impala/config_single_node.yaml | 2 +- sota-implementations/impala/impala_multi_node_ray.py | 6 +++++- sota-implementations/impala/impala_multi_node_submitit.py | 6 +++++- sota-implementations/impala/impala_single_node.py | 7 +++++-- sota-implementations/ppo/ppo_atari.py | 1 + sota-implementations/ppo/ppo_mujoco.py | 1 + 8 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sota-implementations/impala/config_multi_node_ray.yaml b/sota-implementations/impala/config_multi_node_ray.yaml index c67b5ed52da..549428a4725 100644 --- a/sota-implementations/impala/config_multi_node_ray.yaml +++ b/sota-implementations/impala/config_multi_node_ray.yaml @@ -24,7 +24,7 @@ ray_init_config: storage: null # Device for the forward and backward passes -local_device: "cuda:0" +local_device: # Resources assigned to each IMPALA rollout collection worker remote_worker_resources: diff --git a/sota-implementations/impala/config_multi_node_submitit.yaml b/sota-implementations/impala/config_multi_node_submitit.yaml index 59973e46b40..4d4332722aa 100644 --- a/sota-implementations/impala/config_multi_node_submitit.yaml +++ b/sota-implementations/impala/config_multi_node_submitit.yaml @@ -3,7 +3,7 @@ env: env_name: PongNoFrameskip-v4 # Device for the forward and backward passes -local_device: "cuda:0" +local_device: # SLURM config slurm_config: diff --git a/sota-implementations/impala/config_single_node.yaml b/sota-implementations/impala/config_single_node.yaml index b93c3802a33..655edaddc4e 100644 --- a/sota-implementations/impala/config_single_node.yaml +++ b/sota-implementations/impala/config_single_node.yaml @@ -3,7 +3,7 @@ env: env_name: PongNoFrameskip-v4 # Device for the forward and backward passes -device: "cuda:0" +device: # collector collector: diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index ba40de1acde..b2b724f6a6d 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -32,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = torch.device(cfg.local_device) + device = cfg.local_device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index 5f77008a12b..07d38604391 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -34,7 +34,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = torch.device(cfg.local_device) + device = cfg.local_device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index 130d0d30dd7..cd11ae467c3 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -31,7 +31,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = torch.device(cfg.device) + device = cfg.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -55,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create models (check utils.py) actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) # Create collector collector = MultiaSyncDataCollector( diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 8b97f227490..39cd208cdfe 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -178,6 +178,7 @@ def update(batch, num_network_updates): # Update the networks optim.step() return loss.detach().set("alpha", alpha), num_network_updates.clone() + if cfg.compile.compile: update = compile_with_warmup(update, mode=compile_mode, warmup=1) adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 162b8e701df..9ae4d549f3e 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -163,6 +163,7 @@ def update(batch, num_network_updates): # Update the networks optim.step() return loss.detach().set("alpha", alpha), num_network_updates.clone() + if cfg.compile.compile: update = compile_with_warmup(update, mode=compile_mode, warmup=1) adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)