Skip to content

Commit

Permalink
[Quality] IMPALA auto-device
Browse files Browse the repository at this point in the history
ghstack-source-id: 04ebbe09714518b328243d4680035d0527465ac5
Pull Request resolved: #2654
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 00352ab commit f95ac39
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/impala/config_multi_node_ray.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ray_init_config:
storage: null

# Device for the forward and backward passes
local_device: "cuda:0"
local_device:

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"
local_device:

# SLURM config
slurm_config:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/config_single_node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
device: "cuda:0"
device:

# collector
collector:
Expand Down
6 changes: 5 additions & 1 deletion sota-implementations/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.local_device)
device = cfg.local_device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand Down
6 changes: 5 additions & 1 deletion sota-implementations/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.local_device)
device = cfg.local_device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand Down
7 changes: 5 additions & 2 deletions sota-implementations/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.device)
device = cfg.device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand All @@ -55,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create models (check utils.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = MultiaSyncDataCollector(
Expand Down

0 comments on commit f95ac39

Please sign in to comment.