Skip to content

Commit

Permalink
training risk model from scratch along with PPO
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kaustubh Mani committed Sep 12, 2023
1 parent 6ba5107 commit 3f77239
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -127,13 +127,13 @@ def parse_args():
help="the learning rate of the optimizer")
parser.add_argument("--risk-batch-size", type=int, default=1000,
help="number of epochs to update the risk model")
parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--start-risk-update", type=int, default=10000,
help="number of epochs to update the risk model")
parser.add_argument("--rb-type", type=str, default="balanced",
help="which type of replay buffer to use for ")
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--weight", type=float, default=1.0,
help="weight for the 1 class in BCE loss")
Expand Down Expand Up @@ -357,9 +357,9 @@ def risk_sgd_step(cfg, model, batch, criterion, opt, device):

def train_risk(cfg, model, data, criterion, opt, device):
model.train()
dataset = RiskyDataset(data["next_obs"], data["actions"], data["risks"], False, risk_type=cfg.risk_type,
dataset = RiskyDataset(data["next_obs"].to('cpu'), data["actions"].to('cpu'), data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device=device))
dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
net_loss = 0
for batch in dataloader:
pred = model(batch[0].to(device))
Expand Down Expand Up @@ -449,8 +449,8 @@ def train(cfg):
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=False, out_size=risk_size)
if os.path.exists(cfg.risk_model_path):
risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
risk_model.to(device)
print("risk model loaded successfully")
risk_model.to(device)
#print("risk model loaded successfully")
if cfg.fine_tune_risk:
# print("Fine Tuning risk")
## Freezing all except last layer of the risk model
Expand Down Expand Up @@ -537,7 +537,7 @@ def train(cfg):

if cfg.use_risk:
with torch.no_grad():
next_risk = torch.Tensor(risk_model(next_obs)).to(device)
next_risk = torch.Tensor(risk_model(next_obs.to(device))).to(device)
if cfg.risk_type == "continuous":
next_risk = next_risk.unsqueeze(0)
#print(next_risk.size())
Expand Down

0 comments on commit 3f77239

Please sign in to comment.