From b9cc679f6cd713df523567b46a043724f654f344 Mon Sep 17 00:00:00 2001 From: Kaustubh Mani Date: Tue, 19 Dec 2023 13:27:22 -0500 Subject: [PATCH] add some changes --- cleanrl/ppo_continuous_action_wandb.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py index 71d3616df..2a7092146 100644 --- a/cleanrl/ppo_continuous_action_wandb.py +++ b/cleanrl/ppo_continuous_action_wandb.py @@ -510,6 +510,7 @@ def train(cfg): random.seed(cfg.seed) np.random.seed(cfg.seed) torch.manual_seed(cfg.model_seed) + torch.set_num_threads(4) torch.backends.cudnn.deterministic = cfg.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") @@ -558,7 +559,7 @@ def train(cfg): agent = RiskAgent(envs=envs, risk_size=risk_size).to(device) #else: # agent = ContRiskAgent(envs=envs).to(device) - risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk") + risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk", device=device) if os.path.exists(cfg.risk_model_path): risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device)) print("Pretrained risk model loaded successfully") @@ -661,7 +662,7 @@ def train(cfg): if cfg.use_risk: with torch.no_grad(): next_obs_risk = get_risk_obs(cfg, next_obs) - next_risk = torch.Tensor(risk_model(next_obs_risk.to(device))).to(device) + next_risk = risk_model(next_obs_risk.to(device)) if cfg.risk_type == "continuous": next_risk = next_risk.unsqueeze(0) #print(next_risk.size()) @@ -694,7 +695,7 @@ def train(cfg): actions[step] = action logprobs[step] = logprob - if cfg.use_risk: + if cfg.use_risk and False: with torch.no_grad(): candidates = candidates.squeeze() # print(next_obs_risk.repeat(5, 1).size(), candidates.size()) @@ -767,12 +768,12 @@ def train(cfg): else: data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk) state = torch.cat([data["obs"], data["next_obs"]], axis=0) - actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0) + actions = torch.cat([data["actions"].squeeze(), torch.zeros_like(data["actions"]).squeeze()], axis=0) dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0) print(state.size(), actions.size(), dist_to_fail.size()) risk_dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device), True, risk_type=cfg.risk_type, fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num) - risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device=device)) + risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device="cpu")) risk_loss = train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action")