Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents e5774d1 + e1b471a commit 5789a15
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
2 changes: 2 additions & 0 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import eval_model, make_dqn_model, make_env

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/dqn/dqn_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_cartpole import eval_model, make_dqn_model, make_env

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def main(cfg: "DictConfig"): # noqa: F821
lmbda=cfg.ppo.loss.gae_lambda,
value_network=critic,
average_gae=False,
device=device,
)

loss_module = ClipPPOLoss(
Expand Down
7 changes: 5 additions & 2 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def make_ppo_models_state(proof_environment, compile, device):
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
- "safe_tanh": not compile,
+ # "safe_tanh": not compile,
}

# Define policy architecture
Expand All @@ -77,7 +77,9 @@ def make_ppo_models_state(proof_environment, compile, device):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
- proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8
+ proof_environment.action_spec_unbatched.shape[-1],
+ scale_lb=1e-8,
+ device=device,
),
)

Expand All @@ -102,6 +104,7 @@ def make_ppo_models_state(proof_environment, compile, device):
activation_class=torch.nn.Tanh,
out_features=1,
num_cells=[64, 64],
device=device,
)

# Initialize value weights
Expand Down

0 comments on commit 5789a15

Please sign in to comment.