diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index cc5225fafbb..3810260760a 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -71,8 +71,8 @@ def main(cfg: "DictConfig"): # noqa: F821 policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, - device="cpu", - storing_device="cpu", + device=device, + storing_device=device, max_frames_per_traj=-1, compile_policy={"mode": compile_mode} if compile_mode else False, cudagraph_policy=cfg.compile.cudagraphs, @@ -96,6 +96,7 @@ def main(cfg: "DictConfig"): # noqa: F821 value_network=critic, average_gae=False, device=device, + vectorized=not cfg.compile.compile, ) loss_module = ClipPPOLoss( actor_network=actor, diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index d8cb2f8a8a2..1f4700c58f0 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -98,6 +98,7 @@ def main(cfg: "DictConfig"): # noqa: F821 value_network=critic, average_gae=False, device=device, + vectorized=not cfg.compile.compile, ) loss_module = ClipPPOLoss(