Skip to content

Commit

Permalink
saving the csc model
Browse files Browse the repository at this point in the history
  • Loading branch information
manila95 committed Jan 29, 2024
1 parent 0d3c3eb commit e23a7cc
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def train(cfg):
cfg.use_risk = False if cfg.risk_model_path == "None" else True

import wandb
run = wandb.init(config=vars(cfg), entity="manila95",
run = wandb.init(config=vars(cfg), entity="kaustubh95",
project="risk-aware-exploration",
monitor_gym=True,
sync_tensorboard=True, save_code=True)
Expand Down Expand Up @@ -956,6 +956,9 @@ def train(cfg):
writer.add_scalar("Results/Avg_Return", avg_mean_score, global_step)
torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "policy.pt"))
wandb.save("policy.pt")
if cfg.use_csc:
torch.save(qf1.state_dict(), os.path.join(wandb.run.dir, "csc.pt"))
wandb.save("csc.pt")
if cfg.use_risk:
torch.save(risk_model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
wandb.save("risk_model.pt")
Expand Down

0 comments on commit e23a7cc

Please sign in to comment.