Skip to content

Commit

Permalink
[Quality] IMPALA auto-device
Browse files Browse the repository at this point in the history
ghstack-source-id: 202e8a48b78cc03277f928cbab696d26253bc0ee
Pull Request resolved: #2654
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 4b65245 commit 5bb6abd
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/impala/config_multi_node_ray.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ray_init_config:
storage: null

# Device for the forward and backward passes
local_device: "cuda:0"
local_device:

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"
local_device:

# SLURM config
slurm_config:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/config_single_node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
device: "cuda:0"
device:

# collector
collector:
Expand Down
6 changes: 5 additions & 1 deletion sota-implementations/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.local_device)
device = cfg.local_device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand Down
6 changes: 5 additions & 1 deletion sota-implementations/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.local_device)
device = cfg.local_device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand Down
7 changes: 5 additions & 2 deletions sota-implementations/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.device)
device = cfg.device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand All @@ -55,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create models (check utils.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = MultiaSyncDataCollector(
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def update(batch, num_network_updates):
# Update the networks
optim.step()
return loss.detach().set("alpha", alpha), num_network_updates.clone()

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)
adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def update(batch, num_network_updates):
# Update the networks
optim.step()
return loss.detach().set("alpha", alpha), num_network_updates.clone()

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)
adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
Expand Down

0 comments on commit 5bb6abd

Please sign in to comment.