Skip to content

Commit

Permalink
risk obs input correction
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaustubh Mani committed Sep 18, 2023
1 parent 9730b68 commit 5ccc53d
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,20 @@ def train(cfg):
else:
criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor)

if "goal" in cfg.risk_model_path.lower():
risk_obs_size = 72
elif cfg.risk_model_path == "scratch":
risk_obs_size = np.array(envs.single_observation_space.shape).prod()
else:
risk_obs_size = 88

if cfg.use_risk:
print("using risk")
#if cfg.risk_type == "binary":
agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
#else:
# agent = ContRiskAgent(envs=envs).to(device)
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=True, out_size=risk_size)
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=risk_obs_size, batch_norm=True, out_size=risk_size)
if os.path.exists(cfg.risk_model_path):
risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
print("Pretrained risk model loaded successfully")
Expand Down

0 comments on commit 5ccc53d

Please sign in to comment.