diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 0185c835c08..a0cf2726aca 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -192,7 +192,7 @@ def update(batch, num_network_updates): # extract cfg variables cfg_loss_ppo_epochs = cfg.loss.ppo_epochs cfg_optim_anneal_lr = cfg.optim.anneal_lr - cfg_optim_lr = cfg.optim.lr + cfg_optim_lr = torch.tensor(cfg.optim.lr, device=device) cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon cfg_loss_clip_epsilon = cfg.loss.clip_epsilon cfg_logger_test_interval = cfg.logger.test_interval diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 584945013dc..1f224b81528 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -77,7 +77,7 @@ def make_ppo_models_state(proof_environment, device): policy_mlp, AddStateIndependentNormalScale( proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 - ), + ).to(device), ) # Add probabilistic sampling of the actions