From 0c6081dd16bb0efb570d4e85f8c9055b2ddc8ec0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 13:28:51 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/ppo/ppo_atari.py | 5 +++-- sota-implementations/ppo/ppo_mujoco.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index b054e36104b..fbecf7ab0a8 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -71,8 +71,8 @@ def main(cfg: "DictConfig"): # noqa: F821 policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, - device="cpu", - storing_device="cpu", + device=device, + storing_device=device, max_frames_per_traj=-1, compile_policy={"mode": compile_mode} if compile_mode else False, cudagraph_policy=cfg.compile.cudagraphs, @@ -96,6 +96,7 @@ def main(cfg: "DictConfig"): # noqa: F821 value_network=critic, average_gae=False, device=device, + vectorized=not cfg.compile.compile, ) loss_module = ClipPPOLoss( actor_network=actor, diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 44c5693a8c6..849aac38f41 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -98,6 +98,7 @@ def main(cfg: "DictConfig"): # noqa: F821 value_network=critic, average_gae=False, device=device, + vectorized=not cfg.compile.compile, ) loss_module = ClipPPOLoss(