Skip to content

Commit

Permalink
add some changes
Browse files: browse the repository at this point in the history
  • Loading branch information
manila95 committed Dec 19, 2023
1 parent a974cd6 commit b9cc679
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def train(cfg):
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.model_seed)
torch.set_num_threads(4)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
Expand Down Expand Up @@ -558,7 +559,7 @@ def train(cfg):
agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
#else:
# agent = ContRiskAgent(envs=envs).to(device)
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk")
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk", device=device)
if os.path.exists(cfg.risk_model_path):
risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
print("Pretrained risk model loaded successfully")
Expand Down Expand Up @@ -661,7 +662,7 @@ def train(cfg):
if cfg.use_risk:
with torch.no_grad():
next_obs_risk = get_risk_obs(cfg, next_obs)
next_risk = torch.Tensor(risk_model(next_obs_risk.to(device))).to(device)
next_risk = risk_model(next_obs_risk.to(device))
if cfg.risk_type == "continuous":
next_risk = next_risk.unsqueeze(0)
#print(next_risk.size())
Expand Down Expand Up @@ -694,7 +695,7 @@ def train(cfg):
actions[step] = action
logprobs[step] = logprob

if cfg.use_risk:
if cfg.use_risk and False:
with torch.no_grad():
candidates = candidates.squeeze()
# print(next_obs_risk.repeat(5, 1).size(), candidates.size())
Expand Down Expand Up @@ -767,12 +768,12 @@ def train(cfg):
else:
data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
state = torch.cat([data["obs"], data["next_obs"]], axis=0)
actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0)
actions = torch.cat([data["actions"].squeeze(), torch.zeros_like(data["actions"]).squeeze()], axis=0)
dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0)
print(state.size(), actions.size(), dist_to_fail.size())
risk_dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device), True, risk_type=cfg.risk_type,
fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device=device))
risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device="cpu"))

risk_loss = train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action")

Expand Down

0 comments on commit b9cc679

Please sign in to comment.