From b46947d9ad86e38128774a7eb8300c6c35adef56 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 14:01:49 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/ppo/ppo_mujoco.py | 2 +- sota-implementations/ppo/utils_mujoco.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 3c73c8ca8f8..acba881b84b 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -191,7 +191,7 @@ def update(batch, num_network_updates): # extract cfg variables cfg_loss_ppo_epochs = cfg.loss.ppo_epochs cfg_optim_anneal_lr = cfg.optim.anneal_lr - cfg_optim_lr = cfg.optim.lr + cfg_optim_lr = torch.tensor(cfg.optim.lr, device=device) cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon cfg_loss_clip_epsilon = cfg.loss.clip_epsilon cfg_logger_test_interval = cfg.logger.test_interval diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 584945013dc..1f224b81528 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -77,7 +77,7 @@ def make_ppo_models_state(proof_environment, device): policy_mlp, AddStateIndependentNormalScale( proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 - ), + ).to(device), ) # Add probabilistic sampling of the actions