diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index e4241520..3f587709 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -609,7 +609,7 @@ def train(cfg):
     cfg.use_risk = False if cfg.risk_model_path == "None" else True

     import wandb
-    run = wandb.init(config=vars(cfg), entity="manila95",
+    run = wandb.init(config=vars(cfg), entity="kaustubh95",
                      project="risk-aware-exploration", monitor_gym=True,
                      sync_tensorboard=True, save_code=True)

@@ -956,6 +956,9 @@ def train(cfg):
                 writer.add_scalar("Results/Avg_Return", avg_mean_score, global_step)
             torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "policy.pt"))
             wandb.save("policy.pt")
+            if cfg.use_csc:
+                torch.save(qf1.state_dict(), os.path.join(wandb.run.dir, "csc.pt"))
+                wandb.save("csc.pt")
             if cfg.use_risk:
                 torch.save(risk_model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
                 wandb.save("risk_model.pt")
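Note (illustrative, not part of the patch): checkpoints written via wandb.save, like the policy.pt and csc.pt files above, can later be pulled back with wandb.restore. The sketch below assumes a placeholder run path and a hypothetical Agent class standing in for whatever policy class agent is an instance of in the training script.

# Sketch: reloading the policy checkpoint saved by this patch.
# "kaustubh95/risk-aware-exploration/<run_id>" is a placeholder run path;
# Agent(envs) is hypothetical and mirrors the agent constructed in train(cfg).
import torch
import wandb

ckpt = wandb.restore("policy.pt", run_path="kaustubh95/risk-aware-exploration/<run_id>")
agent = Agent(envs)  # hypothetical: the policy class used in the training script
agent.load_state_dict(torch.load(ckpt.name, map_location="cpu"))
agent.eval()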