Skip to content

Commit

Permalink
Update (base update)
Browse files (browse the repository at this point in the history)
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 9ea2c12 commit e1b471a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def main(cfg: "DictConfig"): # noqa: F821
lmbda=cfg.ppo.loss.gae_lambda,
value_network=critic,
average_gae=False,
+ device=device,
)

loss_module = ClipPPOLoss(
Expand Down
7 changes: 5 additions & 2 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def make_ppo_models_state(proof_environment, compile, device):
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
- "safe_tanh": not compile,
+ # "safe_tanh": not compile,
}

# Define policy architecture
Expand All @@ -77,7 +77,9 @@ def make_ppo_models_state(proof_environment, compile, device):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
- proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8
+ proof_environment.action_spec_unbatched.shape[-1],
+ scale_lb=1e-8,
+ device=device,
),
)

Expand All @@ -102,6 +104,7 @@ def make_ppo_models_state(proof_environment, compile, device):
activation_class=torch.nn.Tanh,
out_features=1,
num_cells=[64, 64],
+ device=device,
)

# Initialize value weights
Expand Down

0 comments on commit e1b471a

Please sign in to comment.