From 23cab4156ebc208e2d3f924edd4a97638951646c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 18:45:09 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 1 + sota-implementations/gail/ppo_utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index e6c729eabc3..0f72128318d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -95,6 +95,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.ppo.loss.gae_lambda, value_network=critic, average_gae=False, + device=device, ) loss_module = ClipPPOLoss( diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 6ba12acdf9c..7dcc2db6b74 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -55,7 +55,7 @@ def make_ppo_models_state(proof_environment, compile, device): "low": proof_environment.action_spec_unbatched.space.low.to(device), "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, - "safe_tanh": not compile, + # "safe_tanh": not compile, } # Define policy architecture @@ -77,7 +77,9 @@ def make_ppo_models_state(proof_environment, compile, device): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 + proof_environment.action_spec_unbatched.shape[-1], + scale_lb=1e-8, + device=device, ), ) @@ -102,6 +104,7 @@ def make_ppo_models_state(proof_environment, compile, device): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights