From e6fd43ffeb91786d88435226714ff17c43d7c908 Mon Sep 17 00:00:00 2001
From: kaustubh
Date: Tue, 21 Nov 2023 23:12:46 -0500
Subject: [PATCH] working with risk

---
 cleanrl/dqn.py | 160 +++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 147 insertions(+), 13 deletions(-)

diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py
index 3dfff8ed0..7327a64e2 100644
--- a/cleanrl/dqn.py
+++ b/cleanrl/dqn.py
@@ -11,9 +11,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.buffers import ReplayBuffer as sb3buffer
 from torch.utils.tensorboard import SummaryWriter
 
+from src.models.risk_models import *
+from src.datasets.risk_datasets import *
+from src.utils import *
+
 
 def parse_args():
     # fmt: off
@@ -70,6 +74,54 @@ def parse_args():
         help="timestep to start learning")
     parser.add_argument("--train-frequency", type=int, default=10,
         help="the frequency of training")
+
+    ## Arguments related to the risk model
+    parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to use the risk model")
+    parser.add_argument("--risk-actor", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to use the risk model in the actor")
+    parser.add_argument("--risk-critic", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to use the risk model in the critic")
+    parser.add_argument("--risk-model-path", type=str, default="None",
+        help="path to the pretrained risk model ('None' disables the risk model)")
+    parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to treat risk as a binary label")
+    parser.add_argument("--model-type", type=str, default="bayesian",
+        help="which network to use for the risk model (bayesian or mlp)")
+    parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to use batch normalization in the risk model")
+    parser.add_argument("--risk-type", type=str, default="quantile",
+        help="type of risk estimate: binary, continuous, or quantile")
+    parser.add_argument("--fear-radius", type=int, default=5,
+        help="fear radius (in steps to failure) for training the risk model")
+    parser.add_argument("--num-risk-datapoints", type=int, default=1000,
+        help="number of data points used to train the risk model")
+    parser.add_argument("--risk-update-period", type=int, default=1000,
+        help="how frequently (in global steps) to update the risk model")
+    parser.add_argument("--num-risk-epochs", type=int, default=1,
+        help="number of epochs per risk model update")
+    parser.add_argument("--num-update-risk", type=int, default=10,
+        help="number of SGD steps per risk model update")
+    parser.add_argument("--risk-lr", type=float, default=1e-7,
+        help="the learning rate of the risk model optimizer")
+    parser.add_argument("--risk-batch-size", type=int, default=1000,
+        help="batch size for risk model updates")
+    parser.add_argument("--fine-tune-risk", type=str, default="None",
+        help="method used to fine-tune the risk model ('None' disables fine-tuning)")
+    parser.add_argument("--finetune-risk-online", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to fine-tune the risk model online during training")
+    parser.add_argument("--start-risk-update", type=int, default=10000,
+        help="global step at which to start updating the risk model")
+    parser.add_argument("--rb-type", type=str, default="simple",
+        help="which type of replay buffer to use for risk data (simple or balanced)")
+    parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to freeze the risk model's feature layers during fine-tuning")
+    parser.add_argument("--weight", type=float, default=1.0,
+        help="class weight for the risky class in the risk loss")
+    parser.add_argument("--quantile-size", type=int, default=2, help="size of each risk quantile bin")
+    parser.add_argument("--quantile-num", type=int, default=10, help="number of risk quantile bins")
+    parser.add_argument("--risk-penalty", type=float, default=0., help="penalty to impose for entering risky states")
+    parser.add_argument("--risk-penalty-start", type=float, default=20., help="when to start imposing the risk penalty")
     args = parser.parse_args()
     # fmt: on
     assert args.num_envs == 1, "vectorized envs are not supported at the moment"
@@ -92,20 +144,25 @@ def thunk():
     return thunk
 
 
+
+
 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
-    def __init__(self, env):
+    def __init__(self, env, risk_size=0):
         super().__init__()
         self.network = nn.Sequential(
-            nn.Linear(np.array(env.single_observation_space["image"].shape).prod(), 120),
+            nn.Linear(np.array(env.single_observation_space["image"].shape).prod()+risk_size, 120),
             nn.ReLU(),
             nn.Linear(120, 84),
             nn.ReLU(),
             nn.Linear(84, env.single_action_space.n),
         )
 
-    def forward(self, x):
-        return self.network(x)
+    def forward(self, x, risk=None):
+        if risk is None:
+            return self.network(x)
+        else:
+            return self.network(torch.cat([x, risk], axis=1))
 
 
 def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
@@ -151,19 +208,55 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     torch.set_default_tensor_type('torch.cuda.FloatTensor')
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
-
+    args.use_risk = False if args.risk_model_path == "None" else True
     # env setup
     envs = gym.vector.SyncVectorEnv(
         [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
     )
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
 
-    q_network = QNetwork(envs).to(device)
+
+    risk_model_class = {"bayesian": {"continuous": BayesRiskEstCont, "binary": BayesRiskEst, "quantile": BayesRiskEst},
+                        "mlp": {"continuous": RiskEst, "binary": RiskEst}}
+
+    risk_size_dict = {"continuous": 1, "binary": 2, "quantile": args.quantile_num}
+    risk_size = risk_size_dict[args.risk_type] if args.use_risk else 0
+    risk_bins = np.array([i*args.quantile_size for i in range(args.quantile_num)])
+
+    if args.use_risk:
+        risk_model = risk_model_class[args.model_type][args.risk_type](obs_size=np.array(envs.single_observation_space["image"].shape).prod(), batch_norm=True, out_size=risk_size)
+        if os.path.exists(args.risk_model_path):
+            risk_model.load_state_dict(torch.load(args.risk_model_path, map_location=device))
+            print("Pretrained risk model loaded successfully")
+
+        risk_model.to(device)
+        risk_model.eval()
+
+
+    if args.fine_tune_risk != "None" and args.use_risk:
+        if args.rb_type == "balanced":
+            risk_rb = ReplayBufferBalanced(buffer_size=args.total_timesteps)
+        else:
+            risk_rb = ReplayBuffer(buffer_size=args.total_timesteps)
+            #, observation_space=envs.single_observation_space, action_space=envs.single_action_space)
+        if args.risk_type == "quantile":
+            weight_tensor = torch.Tensor([1]*args.quantile_num).to(device)
+            weight_tensor[0] = args.weight
+        elif args.risk_type == "binary":
+            weight_tensor = torch.Tensor([1., args.weight]).to(device)
+        if args.model_type == "bayesian":
+            criterion = nn.NLLLoss(weight=weight_tensor)
+        else:
+            criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor)
+        opt_risk = optim.Adam(filter(lambda p: p.requires_grad, risk_model.parameters()), lr=args.risk_lr, eps=1e-10)
+
+
+    q_network = QNetwork(envs, risk_size=risk_size).to(device)
     optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
-    target_network = QNetwork(envs).to(device)
+    target_network = QNetwork(envs, risk_size=risk_size).to(device)
     target_network.load_state_dict(q_network.state_dict())
 
-    rb = ReplayBuffer(
+    rb = sb3buffer(
         args.buffer_size,
         envs.single_observation_space["image"],
         envs.single_action_space,
@@ -172,6 +265,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     )
     start_time = time.time()
 
+
+    f_obs, f_next_obs, f_actions = [None]*args.num_envs, [None]*args.num_envs, [None]*args.num_envs
     # TRY NOT TO MODIFY: start the game
     obs, _ = envs.reset(seed=args.seed)
     obs = obs
@@ -183,19 +278,43 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         if random.random() < epsilon:
             actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
         else:
-            q_values = q_network(torch.Tensor(obs["image"]).reshape(args.num_envs, -1).to(device))
+            obs_in = torch.Tensor(obs["image"]).reshape(args.num_envs, -1).to(device)
+            with torch.no_grad():
+                risk = risk_model(obs_in) if args.use_risk else None
+            q_values = q_network(obs_in, risk)
             actions = torch.argmax(q_values, dim=1).cpu().numpy()
 
         # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminated, truncated, infos = envs.step(actions)
         cost = int(terminated) and (rewards == 0)
+        if (args.fine_tune_risk != "None" and args.use_risk) or args.collect_data:
+            for i in range(args.num_envs):
+                f_obs[i] = torch.Tensor(obs["image"][i]).reshape(1, -1).to(device) if f_obs[i] is None else torch.concat([f_obs[i], torch.Tensor(obs["image"][i]).reshape(1, -1).to(device)], axis=0)
+                f_next_obs[i] = torch.Tensor(next_obs["image"][i]).reshape(1, -1).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], torch.Tensor(next_obs["image"][i]).reshape(1, -1).to(device)], axis=0)
+                f_actions[i] = torch.Tensor([actions[i]]).unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], torch.Tensor([actions[i]]).unsqueeze(0).to(device)], axis=0)
+                # f_rewards[i] = reward[i].unsqueeze(0).to(device) if f_rewards[i] is None else torch.concat([f_rewards[i], rewards[i].unsqueeze(0).to(device)], axis=0)
+                # f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
+                # f_costs[i] = cost[i].unsqueeze(0).to(device) if f_costs[i] is None else torch.concat([f_costs[i], cost[i].unsqueeze(0).to(device)], axis=0)
+                # f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)
+
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         if "final_info" in infos:
-            for info in infos["final_info"]:
+            for i, info in enumerate(infos["final_info"]):
                 # Skip the envs that are not done
                 if "episode" not in info:
                     continue
                 total_cost += cost
+                ep_len = info["episode"]["l"]
+                e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
+                e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
+                e_risks = torch.Tensor(e_risks)
+                if args.use_risk and args.fine_tune_risk != "None":
+                    if args.risk_type == "binary":
+                        risk_rb.add(f_obs[i], f_next_obs[i], f_actions[i], None, None, None, (e_risks <= args.fear_radius).float(), e_risks.unsqueeze(1))
+                    else:
+                        risk_rb.add(f_obs[i], f_next_obs[i], f_actions[i], None, None, None, e_risks_quant, e_risks.unsqueeze(1))
+
+                    f_obs[i], f_next_obs[i], f_actions[i] = None, None, None
                 scores.append(info['episode']['r'])
                 print(f"global_step={global_step}, episodic_return={info['episode']['r']}, total cost={total_cost}")
                 writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
@@ -220,9 +339,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         if global_step % args.train_frequency == 0:
             data = rb.sample(args.batch_size)
             with torch.no_grad():
-                target_max, _ = target_network(data.next_observations.reshape(args.batch_size, -1).float()).max(dim=1)
+                next_risk = risk_model(data.next_observations.reshape(args.batch_size, -1).float()) if args.use_risk else None
+                target_max, _ = target_network(data.next_observations.reshape(args.batch_size, -1).float(), next_risk).max(dim=1)
                 td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
-            old_val = q_network(data.observations.reshape(args.batch_size, -1).float()).gather(1, data.actions).squeeze()
+            risk = risk_model(data.observations.reshape(args.batch_size, -1).float()) if args.use_risk else None
+            old_val = q_network(data.observations.reshape(args.batch_size, -1).float(), risk).gather(1, data.actions).squeeze()
             loss = F.mse_loss(td_target, old_val)
 
             if global_step % 100 == 0:
@@ -236,6 +357,19 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
             loss.backward()
             optimizer.step()
 
+            ## Update Risk Network
+            if args.use_risk and args.fine_tune_risk != "None" and global_step % args.risk_update_period == 0:
+                risk_model.train()
+                risk_data = risk_rb.sample(args.risk_batch_size)
+                pred = risk_model(risk_data["next_obs"].to(device))
+                risk_loss = criterion(pred, torch.argmax(risk_data["risks"].squeeze(), axis=1).to(device))
+                opt_risk.zero_grad()
+                risk_loss.backward()
+                opt_risk.step()
+                risk_model.eval()
+                writer.add_scalar("charts/risk_loss", risk_loss.item(), global_step)
+
+
             # update target network
             if global_step % args.target_network_frequency == 0:
                 for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
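
A minimal standalone sketch of how the added code turns a finished episode into risk-model targets (not part of the patch; the names mirror the diff, while the episode length, quantile settings, and failure outcome are made-up values for illustration):

    import numpy as np
    import torch

    quantile_size, quantile_num, fear_radius = 2, 10, 5      # patch defaults
    risk_bins = np.array([i * quantile_size for i in range(quantile_num)])   # bin edges 0, 2, ..., 18

    ep_len, failed = 8, True                                  # hypothetical episode
    # Per-step distance to failure if the episode ended in failure,
    # otherwise a constant "far from failure" value equal to the episode length.
    e_risks = np.array(list(reversed(range(ep_len)))) if failed else np.array([ep_len] * ep_len)

    # Quantile targets: histogram each step's distance to failure into the risk bins.
    e_risks_quant = torch.Tensor(
        np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1))
    )
    # Binary targets: risky (1) within fear_radius steps of failure, safe (0) otherwise.
    binary_targets = (torch.Tensor(e_risks) <= fear_radius).float()

    print(e_risks)              # [7 6 5 4 3 2 1 0]
    print(e_risks_quant.shape)  # torch.Size([8, 9]) -- one row per step, one column per histogram bin
    print(binary_targets)       # tensor([0., 0., 1., 1., 1., 1., 1., 1.])

In the patch, these are the tensors pushed into risk_rb at episode end ((e_risks <= args.fear_radius).float() for the binary case, e_risks_quant otherwise) and later consumed by the risk-model update step.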