Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Quality] IMPALA auto-device #2654

Merged
merged 25 commits into from
Dec 16, 2024
2 changes: 1 addition & 1 deletion sota-implementations/impala/config_multi_node_ray.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ray_init_config:
storage: null

# Device for the forward and backward passes
local_device: "cuda:0"
local_device:

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"
local_device:

# SLURM config
slurm_config:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/config_single_node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
device: "cuda:0"
device:

# collector
collector:
Expand Down
6 changes: 5 additions & 1 deletion sota-implementations/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.local_device)
device = cfg.local_device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand Down
6 changes: 5 additions & 1 deletion sota-implementations/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.local_device)
device = cfg.local_device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand Down
7 changes: 5 additions & 2 deletions sota-implementations/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = torch.device(cfg.device)
device = cfg.device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand All @@ -55,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create models (check utils.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = MultiaSyncDataCollector(
Expand Down
Loading