Commit

able to train risk model from scratch along with the policy
manila95 committed Sep 12, 2023
1 parent 8923cdd commit 6ba5107
Showing 1 changed file with 66 additions and 29 deletions.
cleanrl/ppo_continuous_action_wandb.py (95 changes: 66 additions & 29 deletions)
@@ -21,6 +21,7 @@
 from stable_baselines3.common.buffers import *

 from src.models.risk_models import *
+from src.datasets.risk_datasets import *
 from src.utils import *

 import hydra
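The newly imported src.datasets.risk_datasets module supplies the RiskyDataset class consumed by the train_risk helper added further down. Its implementation is not part of this diff; the following is a minimal sketch, assuming it wraps observation/risk tensors and derives class labels from the recorded distance-to-failure. The constructor arguments mirror the call site below, but the internals are hypothetical:

# Hypothetical sketch of RiskyDataset; the real class lives in
# src/datasets/risk_datasets.py and is not shown in this commit.
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class RiskyDataset(Dataset):
    def __init__(self, obs, actions, risks, use_actions=False, risk_type="binary",
                 fear_clip=None, fear_radius=5, one_hot=True,
                 quantile_size=4, quantile_num=5):
        self.obs = obs
        if risk_type == "binary":
            # assumption: label 1 when failure occurs within fear_radius steps
            labels = (risks.squeeze(-1) <= fear_radius).long()
            num_classes = 2
        else:  # "quantile": bucket distance-to-failure into quantile_num bins
            labels = torch.clamp((risks.squeeze(-1) / quantile_size).long(),
                                 max=quantile_num - 1)
            num_classes = quantile_num
        self.labels = F.one_hot(labels, num_classes).float() if one_hot else labels

    def __len__(self):
        return len(self.obs)

    def __getitem__(self, idx):
        return self.obs[idx], self.labels[idx]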
@@ -64,7 +65,7 @@ def parse_args():
parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
help="the storage path for the data collected")

parser.add_argument("--total-timesteps", type=int, default=10000,
parser.add_argument("--total-timesteps", type=int, default=100000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=3e-4,
help="the learning rate of the optimizer")
@@ -112,19 +113,19 @@ def parse_args():
help="Use risk model in the critic or not ")
parser.add_argument("--model-type", type=str, default="mlp",
help="specify the NN to use for the risk model")
parser.add_argument("--risk-type", type=str, default="discrete",
help="whether the risk is discrete or continuous")
parser.add_argument("--risk-type", type=str, default="binary",
help="whether the risk is binary or continuous")
parser.add_argument("--fear-radius", type=int, default=5,
help="fear radius for training the risk model")
parser.add_argument("--num-risk-datapoints", type=int, default=1000,
help="fear radius for training the risk model")
parser.add_argument("--update-risk-model", type=int, default=1000,
help="number of epochs to update the risk model")
parser.add_argument("--risk-epochs", type=int, default=10,
help="number of epochs to update the risk model")
parser.add_argument("--risk-update-period", type=int, default=1000,
help="how frequently to update the risk model")
parser.add_argument("--num-update-risk", type=int, default=10,
help="number of sgd steps to update the risk model")
parser.add_argument("--risk-lr", type=float, default=1e-7,
help="the learning rate of the optimizer")
parser.add_argument("--risk-batch-size", type=int, default=10,
parser.add_argument("--risk-batch-size", type=int, default=1000,
help="number of epochs to update the risk model")
parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
@@ -134,6 +135,8 @@ def parse_args():
help="which type of replay buffer to use for ")
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--weight", type=float, default=1.0,
help="weight for the 1 class in BCE loss")
parser.add_argument("--quantile-size", type=int, default=4, help="size of the risk quantile ")
parser.add_argument("--quantile-num", type=int, default=5, help="number of quantiles to make")
args = parser.parse_args()
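Taken together, the new flags define the online fine-tuning schedule: every --risk-update-period environment steps (after --start-risk-update), the trainer samples --risk-batch-size x --num-update-risk transitions from the replay buffer and makes one pass over them in mini-batches of --risk-batch-size. A quick check of what the defaults imply (plain arithmetic, not repo code):

# What the new defaults work out to (values from the argparse block above).
risk_update_period = 1000      # --risk-update-period
risk_batch_size = 1000         # --risk-batch-size
num_update_risk = 10           # --num-update-risk
total_timesteps = 100_000      # --total-timesteps (new default)

transitions_per_update = risk_batch_size * num_update_risk   # 10,000
updates_per_run = total_timesteps // risk_update_period      # 100, ignoring the warm-up
print(transitions_per_update, updates_per_run)               # 10000 100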
@@ -352,6 +355,31 @@ def risk_sgd_step(cfg, model, batch, criterion, opt, device):
     return loss


+def train_risk(cfg, model, data, criterion, opt, device):
+    model.train()
+    dataset = RiskyDataset(data["next_obs"], data["actions"], data["risks"], False, risk_type=cfg.risk_type,
+                           fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+    dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device=device))
+    net_loss = 0
+    for batch in dataloader:
+        pred = model(batch[0].to(device))
+        if cfg.model_type == "mlp":
+            loss = criterion(pred, batch[1].squeeze().to(device))
+        else:
+            loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
+        opt.zero_grad()
+        loss.backward()
+        opt.step()
+
+        net_loss += loss.item()
+
+    model.eval()
+    return net_loss
+
+
+
+
+
 def train(cfg):
     # fmt: on

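A hedged smoke test for the new helper, with a stand-in model and synthetic tensors (RiskEst and the cfg object come from the surrounding repo; everything named here is illustrative):

# Illustrative only: exercises train_risk with a stand-in MLP and fake data.
import torch
import torch.nn as nn
import torch.optim as optim

obs_dim, n = 8, 64
data = {
    "next_obs": torch.randn(n, obs_dim),
    "actions":  torch.zeros(n, 2),                     # unused when use_actions=False
    "risks":    torch.randint(0, 20, (n, 1)).float(),  # distance-to-failure, in steps
}
model = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.BCEWithLogitsLoss()                     # matches the "mlp" branch
opt = optim.Adam(model.parameters(), lr=1e-7)
net_loss = train_risk(cfg, model, data, criterion, opt, device="cpu")  # cfg as parsed above

Note that train_risk accumulates loss.item() over every mini-batch, so the value logged to risk/risk_loss is a sum, not a mean; dividing by len(dataloader) would make runs with different --num-update-risk settings comparable.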
@@ -391,44 +419,50 @@ def train(cfg):
     )
     assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

-    risk_model_class = {"bayesian": {"continuous": BayesRiskEstCont, "discrete": BayesRiskEst, "quantile": BayesRiskEst},
-                        "mlp": {"continuous": RiskEst, "discrete": RiskEst}}
+    risk_model_class = {"bayesian": {"continuous": BayesRiskEstCont, "binary": BayesRiskEst, "quantile": BayesRiskEst},
+                        "mlp": {"continuous": RiskEst, "binary": RiskEst}}

-    risk_size_dict = {"continuous": 1, "discrete": 2, "quantile": cfg.quantile_num}
+    risk_size_dict = {"continuous": 1, "binary": 2, "quantile": cfg.quantile_num}
     risk_size = risk_size_dict[cfg.risk_type]
     if cfg.fine_tune_risk:
         if cfg.rb_type == "balanced":
             rb = ReplayBufferBalanced(buffer_size=cfg.total_timesteps)
         else:
             rb = ReplayBuffer(buffer_size=cfg.total_timesteps)
             #, observation_space=envs.single_observation_space, action_space=envs.single_action_space)
+        if cfg.risk_type == "quantile":
+            weight_tensor = torch.Tensor([1]*cfg.quantile_num).to(device)
+            weight_tensor[0] = cfg.weight
+        elif cfg.risk_type == "binary":
+            weight_tensor = torch.Tensor([1., cfg.weight]).to(device)
         if cfg.model_type == "bayesian":
-            criterion = nn.NLLLoss(weight=torch.Tensor([1, 1.]).to(device))
+            criterion = nn.NLLLoss(weight=weight_tensor)
         else:
-            criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([1, 1.]).to(device))
+            criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor)

     if cfg.use_risk:
         print("using risk")
-        #if cfg.risk_type == "discrete":
+        #if cfg.risk_type == "binary":
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
         #else:
         #    agent = ContRiskAgent(envs=envs).to(device)
+        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=False, out_size=risk_size)
         if os.path.exists(cfg.risk_model_path):
-            risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=False, out_size=risk_size)
             risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
             risk_model.to(device)
             print("risk model loaded successfully")
-            if cfg.fine_tune_risk:
-                ## Freezing all except last layer of the risk model
-                if cfg.freeze_risk_layers:
-                    for param in risk_model.parameters():
-                        param.requires_grad = False
-                    risk_model.out.weight.requires_grad = True
-                    risk_model.out.bias.requires_grad = True
-                opt_risk = optim.Adam(filter(lambda p: p.requires_grad, risk_model.parameters()), lr=cfg.risk_lr, eps=1e-10)
+        if cfg.fine_tune_risk:
+            # print("Fine Tuning risk")
+            ## Freezing all except last layer of the risk model
+            if cfg.freeze_risk_layers:
+                for param in risk_model.parameters():
+                    param.requires_grad = False
+                risk_model.out.weight.requires_grad = True
+                risk_model.out.bias.requires_grad = True
+            opt_risk = optim.Adam(filter(lambda p: p.requires_grad, risk_model.parameters()), lr=cfg.risk_lr, eps=1e-10)
             risk_model.eval()
         else:
-            raise("No model in the path specified!!")
+            print("No model in the path specified!!")
     else:
         agent = Agent(envs=envs).to(device)

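The new --weight flag drives the class weighting built in the hunk above: index 0 (the nearest-to-failure bucket) is up-weighted for quantile risk, and class 1 for binary risk. A self-contained illustration of how the two criteria consume that tensor (example values are arbitrary):

# How `weight` tilts the two loss branches above (arbitrary numbers).
import torch
import torch.nn as nn

weight = 5.0                                     # --weight
logits = torch.tensor([[2.0, -1.0]])             # one sample, two risk classes
target_one_hot = torch.tensor([[0.0, 1.0]])      # true class: 1 (risky)

# "mlp" branch: BCE-with-logits; pos_weight scales each class's positive term
bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0, weight]))
print(bce(logits, target_one_hot))               # larger than with pos_weight=[1, 1]

# "bayesian" branch: NLL over log-probabilities; weight scales per-class NLL
nll = nn.NLLLoss(weight=torch.tensor([1.0, weight]))
print(nll(torch.log_softmax(logits, dim=1), torch.tensor([1])))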
@@ -507,7 +541,7 @@ def train(cfg):
if cfg.risk_type == "continuous":
next_risk = next_risk.unsqueeze(0)
#print(next_risk.size())
if cfg.binary_risk and cfg.risk_type == "discrete":
if cfg.binary_risk and cfg.risk_type == "binary":
id_risk = torch.argmax(next_risk, axis=1)
next_risk = torch.zeros_like(next_risk)
next_risk[:, id_risk] = 1
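One caveat on the unchanged one-hot block above: next_risk[:, id_risk] = 1 assigns whole columns, so it only yields a per-row one-hot vector when the batch dimension is 1 (e.g. a single vectorized env, as here). A batch-safe variant, should the pattern be reused on batched risk estimates:

# Batch-safe one-hot conversion of a risk distribution (sketch).
import torch

next_risk = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
id_risk = torch.argmax(next_risk, dim=1)                   # tensor([1, 0])
one_hot = torch.zeros_like(next_risk)
one_hot[torch.arange(next_risk.size(0)), id_risk] = 1.0    # row-wise assignment
print(one_hot)                                             # [[0., 1.], [1., 0.]]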
@@ -564,11 +598,14 @@

             obs_ = next_obs
             # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
-            if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk):
+            if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
                 #print(global_step)
-                batch = rb.sample(cfg.risk_batch_size)
-                risk_loss = risk_sgd_step(cfg, risk_model, batch, criterion, opt_risk, device)
-                writer.add_scalar("risk/risk_loss", risk_loss, global_step)
+                # update_risk = 0
+                # while update_risk < cfg.num_update_risk:
+                data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
+                risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
+                writer.add_scalar("risk/risk_loss", risk_loss, global_step)
+                # update_risk += 1
             # fine_tune_risk(cfg, risk_model, f_obs[-cfg.num_risk_datapoints:], f_risks[-cfg.num_risk_datapoints:], opt_risk, device)

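Net effect of the last two hunks: the per-step risk_sgd_step call is replaced by a periodic bulk update via train_risk, and, together with the raise-to-print change above, a missing --risk-model-path no longer aborts the run, so the risk model can be trained from scratch alongside the policy, as the commit message says. Condensed, the new schedule inside the rollout loop is (names as in train(); not verbatim):

# Condensed view of the new risk-update schedule.
if (cfg.use_risk and cfg.fine_tune_risk
        and global_step > cfg.start_risk_update
        and global_step % cfg.risk_update_period == 0):
    data = rb.sample(cfg.risk_batch_size * cfg.num_update_risk)  # 10k transitions by default
    risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
    writer.add_scalar("risk/risk_loss", risk_loss, global_step)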