working with risk
manila95 committed Nov 22, 2023
1 parent a69d612 commit e6fd43f
Showing 1 changed file with 147 additions and 13 deletions.
160 changes: 147 additions & 13 deletions cleanrl/dqn.py
@@ -11,9 +11,13 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer as sb3buffer
from torch.utils.tensorboard import SummaryWriter

from src.models.risk_models import *
from src.datasets.risk_datasets import *
from src.utils import *


def parse_args():
# fmt: off
@@ -70,6 +74,54 @@ def parse_args():
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=10,
help="the frequency of training")

## Arguments related to risk model
parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model or not ")
parser.add_argument("--risk-actor", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Use risk model in the actor or not ")
parser.add_argument("--risk-critic", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model in the critic or not ")
parser.add_argument("--risk-model-path", type=str, default="None",
help="the id of the environment")
parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model in the critic or not ")
parser.add_argument("--model-type", type=str, default="bayesian",
help="specify the NN to use for the risk model")
parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--risk-type", type=str, default="quantile",
help="whether the risk is binary or continuous")
parser.add_argument("--fear-radius", type=int, default=5,
help="fear radius for training the risk model")
parser.add_argument("--num-risk-datapoints", type=int, default=1000,
help="fear radius for training the risk model")
parser.add_argument("--risk-update-period", type=int, default=1000,
help="how frequently to update the risk model")
parser.add_argument("--num-risk-epochs", type=int, default=1,
help="number of sgd steps to update the risk model")
parser.add_argument("--num-update-risk", type=int, default=10,
help="number of sgd steps to update the risk model")
parser.add_argument("--risk-lr", type=float, default=1e-7,
help="the learning rate of the optimizer")
parser.add_argument("--risk-batch-size", type=int, default=1000,
help="number of epochs to update the risk model")
parser.add_argument("--fine-tune-risk", type=str, default="None",
help="fine tune risk by which method")
parser.add_argument("--finetune-risk-online", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--start-risk-update", type=int, default=10000,
help="number of epochs to update the risk model")
parser.add_argument("--rb-type", type=str, default="simple",
help="which type of replay buffer to use for ")
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--weight", type=float, default=1.0,
help="weight for the 1 class in BCE loss")
parser.add_argument("--quantile-size", type=int, default=2, help="size of the risk quantile ")
parser.add_argument("--quantile-num", type=int, default=10, help="number of quantiles to make")
parser.add_argument("--risk-penalty", type=float, default=0., help="penalty to impose for entering risky states")
parser.add_argument("--risk-penalty-start", type=float, default=20., help="penalty to impose for entering risky states")
args = parser.parse_args()
# fmt: on
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
@@ -92,20 +144,25 @@ def thunk():
return thunk




# ALGO LOGIC: initialize agent here:
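# The Q-network below optionally takes a risk estimate of size `risk_size`
# (e.g. a distribution over risk quantiles) and concatenates it with the
# flattened image observation before the first linear layer.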
class QNetwork(nn.Module):
def __init__(self, env):
def __init__(self, env, risk_size=0):
super().__init__()
self.network = nn.Sequential(
nn.Linear(np.array(env.single_observation_space["image"].shape).prod() + risk_size, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, env.single_action_space.n),
)

def forward(self, x, risk=None):
    if risk is None:
        return self.network(x)
    else:
        return self.network(torch.cat([x, risk], dim=1))
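# Shape example (assuming a 7x7x3 MiniGrid image observation and the default
# quantile risk of size 10): the first Linear layer sees 147 + 10 = 157 inputs.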


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
@@ -151,19 +208,55 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
torch.set_default_tensor_type('torch.cuda.FloatTensor')

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

args.use_risk = False if args.risk_model_path == "None" else True
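# NOTE: the --use-risk flag is effectively overridden here; a risk model is
# used if and only if a path other than "None" is passed via --risk-model-path.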
# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"


risk_model_class = {"bayesian": {"continuous": BayesRiskEstCont, "binary": BayesRiskEst, "quantile": BayesRiskEst},
"mlp": {"continuous": RiskEst, "binary": RiskEst}}

risk_size_dict = {"continuous": 1, "binary": 2, "quantile": args.quantile_num}
risk_size = risk_size_dict[args.risk_type]
risk_bins = np.array([i*args.quantile_size for i in range(args.quantile_num)])
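# risk_bins are the histogram edges used later to discretize each step's
# "distance to failure" into buckets of width quantile_size; e.g. with the
# defaults (--quantile-size 2 --quantile-num 10), risk_bins = [0, 2, 4, ..., 18].
# The resulting per-step histogram is the target for the quantile risk model.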

if args.use_risk:
risk_model = risk_model_class[args.model_type][args.risk_type](obs_size=np.array(envs.single_observation_space["image"].shape).prod(), batch_norm=True, out_size=risk_size)
if os.path.exists(args.risk_model_path):
risk_model.load_state_dict(torch.load(args.risk_model_path, map_location=device))
print("Pretrained risk model loaded successfully")

risk_model.to(device)
risk_model.eval()


if args.fine_tune_risk != "None" and args.use_risk:
if args.rb_type == "balanced":
risk_rb = ReplayBufferBalanced(buffer_size=args.total_timesteps)
else:
risk_rb = ReplayBuffer(buffer_size=args.total_timesteps)
#, observation_space=envs.single_observation_space, action_space=envs.single_action_space)
if args.risk_type == "quantile":
weight_tensor = torch.Tensor([1]*args.quantile_num).to(device)
weight_tensor[0] = args.weight
elif args.risk_type == "binary":
weight_tensor = torch.Tensor([1., args.weight]).to(device)
if args.model_type == "bayesian":
criterion = nn.NLLLoss(weight=weight_tensor)
else:
criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor)
opt_risk = optim.Adam(filter(lambda p: p.requires_grad, risk_model.parameters()), lr=args.risk_lr, eps=1e-10)
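# The weights above up-weight the rare, safety-critical label via --weight
# (the first quantile bucket, i.e. steps closest to failure, or the positive
# class in the binary case). NLLLoss is used for the Bayesian risk model,
# which presumably outputs log-probabilities, and BCEWithLogitsLoss otherwise.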


q_network = QNetwork(envs, risk_size=risk_size).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
target_network = QNetwork(envs, risk_size=risk_size).to(device)
target_network.load_state_dict(q_network.state_dict())

rb = sb3buffer(
args.buffer_size,
envs.single_observation_space["image"],
envs.single_action_space,
@@ -172,6 +265,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
)
start_time = time.time()


f_obs, f_next_obs, f_actions = [None]*args.num_envs, [None]*args.num_envs, [None]*args.num_envs
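# Per-env episode buffers: observations, next observations and actions are
# accumulated here every step and, once an episode finishes, pushed into
# risk_rb together with the episode-level risk labels computed below.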
# TRY NOT TO MODIFY: start the game
obs, _ = envs.reset(seed=args.seed)
obs = obs
@@ -183,19 +278,43 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if random.random() < epsilon:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
q_values = q_network(torch.Tensor(obs["image"]).reshape(args.num_envs, -1).to(device))
obs_in = torch.Tensor(obs["image"]).reshape(args.num_envs, -1).to(device)
with torch.no_grad():
risk = risk_model(obs_in) if args.use_risk else None
q_values = q_network(obs_in, risk)
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
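# The next line counts a step that terminates with zero reward as a failure
# (cost = 1); this assumes a sparse-reward task (e.g. MiniGrid) where reaching
# the goal always yields a positive reward.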
cost = int(terminated) and (rewards == 0)
if (args.fine_tune_risk != "None" and args.use_risk) or args.collect_data:
for i in range(args.num_envs):
f_obs[i] = torch.Tensor(obs["image"][i]).reshape(1, -1).to(device) if f_obs[i] is None else torch.concat([f_obs[i], torch.Tensor(obs["image"][i]).reshape(1, -1).to(device)], axis=0)
f_next_obs[i] = torch.Tensor(next_obs["image"][i]).reshape(1, -1).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], torch.Tensor(next_obs["image"][i]).reshape(1, -1).to(device)], axis=0)
f_actions[i] = torch.Tensor([actions[i]]).unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], torch.Tensor([actions[i]]).unsqueeze(0).to(device)], axis=0)
# f_rewards[i] = reward[i].unsqueeze(0).to(device) if f_rewards[i] is None else torch.concat([f_rewards[i], rewards[i].unsqueeze(0).to(device)], axis=0)
# f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
# f_costs[i] = cost[i].unsqueeze(0).to(device) if f_costs[i] is None else torch.concat([f_costs[i], cost[i].unsqueeze(0).to(device)], axis=0)
# f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
for i, info in enumerate(infos["final_info"]):
# Skip the envs that are not done
if "episode" not in info:
continue
total_cost += cost
ep_len = info["episode"]["l"]
e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
e_risks = torch.Tensor(e_risks)
if args.use_risk and args.fine_tune_risk != "None":
if args.risk_type == "binary":
risk_rb.add(f_obs[i], f_next_obs[i], f_actions[i], None, None, None, (e_risks <= args.fear_radius).float(), e_risks.unsqueeze(1))
else:
risk_rb.add(f_obs[i], f_next_obs[i], f_actions[i], None, None, None, e_risks_quant, e_risks.unsqueeze(1))

f_obs[i], f_next_obs[i], f_actions[i] = None, None, None
scores.append(info['episode']['r'])
print(f"global_step={global_step}, episodic_return={info['episode']['r']}, total cost={total_cost}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
@@ -220,9 +339,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.train_frequency == 0:
data = rb.sample(args.batch_size)
with torch.no_grad():
next_risk = risk_model(data.next_observations.reshape(args.batch_size, -1).float()) if args.use_risk else None
target_max, _ = target_network(data.next_observations.reshape(args.batch_size, -1).float(), next_risk).max(dim=1)
td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
risk = risk_model(data.observations.reshape(args.batch_size, -1).float()) if args.use_risk else None
old_val = q_network(data.observations.reshape(args.batch_size, -1).float(), risk).gather(1, data.actions).squeeze()
loss = F.mse_loss(td_target, old_val)

if global_step % 100 == 0:
@@ -236,6 +357,19 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss.backward()
optimizer.step()

## Update Risk Network
if args.use_risk and args.fine_tune_risk != "None" and global_step % args.risk_update_period == 0:
risk_model.train()
risk_data = risk_rb.sample(args.risk_batch_size)
pred = risk_model(risk_data["next_obs"].to(device))
risk_loss = criterion(pred, torch.argmax(risk_data["risks"].squeeze(), axis=1).to(device))
opt_risk.zero_grad()
risk_loss.backward()
opt_risk.step()
risk_model.eval()
writer.add_scalar("charts/risk_loss", risk_loss.item(), global_step)


# update target network
if global_step % args.target_network_frequency == 0:
for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
