sac with safety-gymnasium
Kaustubh Mani committed Sep 29, 2023
1 parent 1c9da44 commit d6519be
Showing 1 changed file with 142 additions and 50 deletions.
192 changes: 142 additions & 50 deletions cleanrl/sac_continuous_action.py
@@ -5,7 +5,8 @@
import time
from distutils.util import strtobool

import gym
import safety_gymnasium
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -14,12 +15,26 @@
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from src.models.risk_models import *
from src.datasets.risk_datasets import *
# from src.utils import *


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to terminate early i.e. when the catastrophe has happened")
parser.add_argument("--term-cost", type=int, default=1,
help="how many violations before you terminate")
parser.add_argument("--failure-penalty", type=float, default=0.0,
help="Reward Penalty when you fail")
parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="store data while trianing")
parser.add_argument("--storage-path", type=str, default="./data/sac/term_1",
help="the storage path for the data collected")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
@@ -64,26 +79,79 @@ def parse_args():
help="Entropy regularization coefficient.")
parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
help="automatic tuning of the entropy coefficient")

## Arguments related to risk model
parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model or not ")
parser.add_argument("--risk-actor", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Use risk model in the actor or not ")
parser.add_argument("--risk-critic", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model in the critic or not ")
parser.add_argument("--risk-model-path", type=str, default="None",
help="the id of the environment")
parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model in the critic or not ")
parser.add_argument("--model-type", type=str, default="mlp",
help="specify the NN to use for the risk model")
parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--risk-type", type=str, default="binary",
help="whether the risk is binary or continuous")
parser.add_argument("--fear-radius", type=int, default=5,
help="fear radius for training the risk model")
parser.add_argument("--num-risk-datapoints", type=int, default=1000,
help="fear radius for training the risk model")
parser.add_argument("--risk-update-period", type=int, default=1000,
help="how frequently to update the risk model")
parser.add_argument("--num-update-risk", type=int, default=10,
help="number of sgd steps to update the risk model")
parser.add_argument("--risk-lr", type=float, default=1e-7,
help="the learning rate of the optimizer")
parser.add_argument("--risk-batch-size", type=int, default=1000,
help="number of epochs to update the risk model")
parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--finetune-risk-online", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--start-risk-update", type=int, default=10000,
help="number of epochs to update the risk model")
parser.add_argument("--rb-type", type=str, default="balanced",
help="which type of replay buffer to use for ")
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--weight", type=float, default=1.0,
help="weight for the 1 class in BCE loss")
parser.add_argument("--quantile-size", type=int, default=4, help="size of the risk quantile ")
parser.add_argument("--quantile-num", type=int, default=5, help="number of quantiles to make")
args = parser.parse_args()
# fmt: on
return args
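
# Example invocation (illustrative only: the task id and checkpoint path are placeholders;
# --env-id and --total-timesteps come from the unchanged cleanrl arguments above):
#   python cleanrl/sac_continuous_action.py --env-id SafetyPointGoal1-v0 \
#       --use-risk True --risk-model-path models/risk.pt --total-timesteps 1000000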


def make_env(env_id, seed, idx, capture_video, run_name):

def make_env(cfg, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty)
else:
env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty)
env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
# env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
env = gym.wrappers.ClipAction(env)
env = gym.wrappers.NormalizeObservation(env)
env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
env = gym.wrappers.NormalizeReward(env, gamma=cfg.gamma)
env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
return env

return thunk
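
# A minimal sketch of how the thunk above is consumed and where the safety cost
# surfaces (illustrative only; the early_termination / term_cost / failure_penalty
# kwargs are assumed to be supported by the safety_gymnasium build used here):
#
#   env = make_env(cfg, cfg.seed, 0, False, "debug-run")()
#   obs, info = env.reset(seed=cfg.seed)
#   obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#   step_cost = info.get("cost", 0.0)  # per-step safety cost reported by safety-gymnasium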


# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
def __init__(self, env):
@@ -145,36 +213,36 @@ def get_action(self, x):


if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
cfg = parse_args()
run_name = f"{cfg.env_id}__{cfg.exp_name}__{cfg.seed}__{int(time.time())}"
if cfg.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
project=cfg.wandb_project_name,
entity=cfg.wandb_entity,
sync_tensorboard=True,
config=vars(args),
config=vars(cfg),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(cfg).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv([make_env(cfg, cfg.seed, 0, cfg.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(envs.single_action_space.high[0])
@@ -186,71 +254,95 @@ def get_action(self, x):
qf2_target = SoftQNetwork(envs).to(device)
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=cfg.q_lr)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=cfg.policy_lr)

# Automatic entropy tuning
if args.autotune:
if cfg.autotune:
target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha = log_alpha.exp().item()
a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
a_optimizer = optim.Adam([log_alpha], lr=cfg.q_lr)
else:
alpha = args.alpha
alpha = cfg.alpha

envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(
args.buffer_size,
cfg.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

# running counters for the per-episode and cumulative safety cost
cum_cost, ep_cost = 0, 0

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
for global_step in range(args.total_timesteps):
obs, _ = envs.reset()
for global_step in range(cfg.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
if global_step < cfg.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
actions = actions.detach().cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
# next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
dones = np.logical_or(terminated, truncated)
if not dones:
cost = torch.Tensor(infos["cost"]).to(device).view(-1)
ep_cost += infos["cost"]; cum_cost += infos["cost"]
else:
cost = torch.Tensor(np.array([infos["final_info"][0]["cost"]])).to(device).view(-1)
ep_cost += np.array([infos["final_info"][0]["cost"]]); cum_cost += np.array([infos["final_info"][0]["cost"]])
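# Note: with gymnasium's autoreset vector API the per-step safety cost arrives in
# infos["cost"]; on the step where an episode ends it is instead nested under
# infos["final_info"][0]["cost"], hence the two branches above.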



# print(infos)
infos = [infos]
if "final_info" not in infos[0]:
continue

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break
for info in infos[0]["final_info"]:
# print(info)
if info is None:
continue

print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("Performance/episodic_cost", ep_cost, global_step)
writer.add_scalar("Performance/cummulative_cost", cum_cost, global_step)
ep_cost = 0
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
# print(infos)
real_next_obs[idx] = infos[idx]["_final_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)
if global_step > cfg.learning_starts:
data = rb.sample(cfg.batch_size)
with torch.no_grad():
next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)
qf1_next_target = qf1_target(data.next_observations, next_state_actions)
qf2_next_target = qf2_target(data.next_observations, next_state_actions)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * cfg.gamma * (min_qf_next_target).view(-1)
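# soft Bellman target: y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))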

qf1_a_values = qf1(data.observations, data.actions).view(-1)
qf2_a_values = qf2(data.observations, data.actions).view(-1)
qf1_a_values = qf1(data.observations.float(), data.actions.float()).view(-1)
qf2_a_values = qf2(data.observations.float(), data.actions.float()).view(-1)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss
@@ -259,23 +351,23 @@ def get_action(self, x):
qf_loss.backward()
q_optimizer.step()

if global_step % args.policy_frequency == 0: # TD 3 Delayed update support
if global_step % cfg.policy_frequency == 0: # TD 3 Delayed update support
for _ in range(
args.policy_frequency
cfg.policy_frequency
): # compensate for the delay by doing 'actor_update_interval' instead of 1
pi, log_pi, _ = actor.get_action(data.observations)
qf1_pi = qf1(data.observations, pi)
qf2_pi = qf2(data.observations, pi)
qf1_pi = qf1(data.observations.float(), pi.float())
qf2_pi = qf2(data.observations.float(), pi.float())
min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
actor_loss = ((alpha * log_pi) - min_qf_pi).mean()
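# actor loss: E[ alpha * log pi(a|s) - min(Q1, Q2) ]; minimizing it maximizes the entropy-regularized return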

actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()

if args.autotune:
if cfg.autotune:
with torch.no_grad():
_, log_pi, _ = actor.get_action(data.observations)
_, log_pi, _ = actor.get_action(data.observations.float())
alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()
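# temperature loss: (-log_alpha * (log pi(a|s) + target_entropy)).mean(); alpha grows when policy entropy drops below the target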

a_optimizer.zero_grad()
Expand All @@ -284,11 +376,11 @@ def get_action(self, x):
alpha = log_alpha.exp().item()

# update the target networks
if global_step % args.target_network_frequency == 0:
if global_step % cfg.target_network_frequency == 0:
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
target_param.data.copy_(cfg.tau * param.data + (1 - cfg.tau) * target_param.data)
for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
target_param.data.copy_(cfg.tau * param.data + (1 - cfg.tau) * target_param.data)

if global_step % 100 == 0:
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
@@ -300,7 +392,7 @@ def get_action(self, x):
writer.add_scalar("losses/alpha", alpha, global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
if args.autotune:
if cfg.autotune:
writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

envs.close()