From 5ccc53d7f7128ebbffe92152d96e0388f6a139e3 Mon Sep 17 00:00:00 2001 From: Kaustubh Mani Date: Sun, 17 Sep 2023 20:26:07 -0400 Subject: [PATCH] risk obs input correction --- cleanrl/ppo_continuous_action_wandb.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py index dddcd82e1..9efa356e8 100644 --- a/cleanrl/ppo_continuous_action_wandb.py +++ b/cleanrl/ppo_continuous_action_wandb.py @@ -475,13 +475,20 @@ def train(cfg): else: criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor) + if "goal" in cfg.risk_model_path.lower(): + risk_obs_size = 72 + elif cfg.risk_model_path == "scratch": + risk_obs_size = np.array(envs.single_observation_space.shape).prod() + else: + risk_obs_size = 88 + if cfg.use_risk: print("using risk") #if cfg.risk_type == "binary": agent = RiskAgent(envs=envs, risk_size=risk_size).to(device) #else: # agent = ContRiskAgent(envs=envs).to(device) - risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=True, out_size=risk_size) + risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=risk_obs_size, batch_norm=True, out_size=risk_size) if os.path.exists(cfg.risk_model_path): risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device)) print("Pretrained risk model loaded successfully")