
Commit

state action risk all new
manila95 committed Mar 20, 2024
1 parent d99febf commit e77187e
Showing 2 changed files with 389 additions and 84 deletions.
183 changes: 99 additions & 84 deletions cleanrl/ppo_continuous_action.py
@@ -6,82 +6,23 @@
from distutils.util import strtobool

import gymnasium as gym
import safety_gymnasium
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=3e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--num-steps", type=int, default=2048,
help="the number of steps to run in each environment per policy rollout")
parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggle learning rate annealing for policy and value networks")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--gae-lambda", type=float, default=0.95,
help="the lambda for the general advantage estimation")
parser.add_argument("--num-minibatches", type=int, default=32,
help="the number of mini-batches")
parser.add_argument("--update-epochs", type=int, default=10,
help="the K epochs to update the policy")
parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles advantages normalization")
parser.add_argument("--clip-coef", type=float, default=0.2,
help="the surrogate clipping coefficient")
parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--ent-coef", type=float, default=0.0,
help="coefficient of the entropy")
parser.add_argument("--vf-coef", type=float, default=0.5,
help="coefficient of the value function")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
help="the maximum norm for the gradient clipping")
parser.add_argument("--target-kl", type=float, default=None,
help="the target KL divergence threshold")
args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
# fmt: on
return args
from utils import *
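The inline parse_args above is removed by this commit; argument parsing now has to come from the repository's utils module via the star import, whose contents are not shown in this diff. The sketch below lists a few of the risk-related flags the rest of the file clearly expects (args.use_risk, args.quantile_num, args.risk_model_path, args.fine_tune_risk, args.risk_lr; it also reads args.risk_input, args.start_risk_update, args.risk_update_period and several environment kwargs). The names match the usage below, but the defaults and help strings are illustrative guesses, not the repository's actual definitions.

# --- editor's sketch, not part of the commit ---
# Risk-related flags inferred from how `args` is used later in this file.
import argparse
from distutils.util import strtobool

def sketch_risk_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="condition the actor/critic on a learned risk estimate")
    parser.add_argument("--quantile-num", type=int, default=10,
                        help="number of quantile bins in the risk distribution (used as risk_size)")
    parser.add_argument("--risk-model-path", type=str, default="",
                        help="optional path to a pretrained risk model checkpoint")
    parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="keep training the risk model online from rollout data")
    parser.add_argument("--risk-lr", type=float, default=1e-4,
                        help="learning rate of the risk model optimizer")
    return parser
# --- end sketch ---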


def make_env(env_id, idx, capture_video, run_name, gamma):
def thunk():
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.make(args.env_id, render_mode="rgb_array", early_termination=args.early_termination, term_cost=args.term_cost, failure_penalty=args.failure_penalty, reward_goal=args.reward_goal, reward_distance=args.reward_distance)
else:
env = gym.make(env_id)
env = gym.make(args.env_id, early_termination=args.early_termination, term_cost=args.term_cost, failure_penalty=args.failure_penalty, reward_goal=args.reward_goal, reward_distance=args.reward_distance)
env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
@@ -104,28 +45,32 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):


class Agent(nn.Module):
def __init__(self, envs):
def __init__(self, envs, risk_size=0):
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+risk_size, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 1), std=1.0),
)
self.actor_mean = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+risk_size, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
)
self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

def get_value(self, x):
def get_value(self, x, risk=None):
if risk is not None:
x = torch.cat([x, risk], axis=-1)
return self.critic(x)

def get_action_and_value(self, x, action=None):
def get_action_and_value(self, x, risk=None, action=None):
if risk is not None:
x = torch.cat([x, risk], axis=-1)
action_mean = self.actor_mean(x)
action_logstd = self.actor_logstd.expand_as(action_mean)
action_std = torch.exp(action_logstd)
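The Agent changes above condition both the actor and the critic on a per-state risk estimate by concatenating it to the observation, which is why the first Linear layer of each network grows by risk_size. The shape check below is a minimal sketch of that input construction; the sizes are illustrative only.

# --- editor's sketch, not part of the commit ---
# The risk vector (a distribution over `risk_size` quantile bins) is
# concatenated to the flattened observation, so the first Linear layer of
# actor and critic takes obs_dim + risk_size input features.
import torch

obs_dim, risk_size, batch = 60, 10, 4
obs = torch.randn(batch, obs_dim)                            # flattened observations
risk = torch.softmax(torch.randn(batch, risk_size), dim=-1)  # per-state risk distribution
x = torch.cat([obs, risk], dim=-1)
assert x.shape == (batch, obs_dim + risk_size)
# --- end sketch ---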
@@ -138,17 +83,17 @@ def get_action_and_value(self, x, action=None):
if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
# if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
@@ -170,9 +115,37 @@ def get_action_and_value(self, x, action=None):
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

agent = Agent(envs).to(device)

risk_size = args.quantile_num if args.use_risk else 0


agent = Agent(envs, risk_size).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)


if args.use_risk:
risk_input_size = np.array(envs.single_observation_space.shape).prod() if args.risk_input == "state" else np.array(envs.single_observation_space.shape).prod() + np.prod(envs.single_action_space.shape)
risk_model = BayesRiskEst(risk_input_size, risk_size=risk_size)

if os.path.exists(args.risk_model_path):
risk_model.load_state_dict(torch.load(args.risk_model_path, map_location=device))
print("Risk model loaded successfully")

risk_model.to(device)

if args.fine_tune_risk:
## Initializing replay buffer
risk_rb = ReplayBuffer()
## Initializing optimizer
opt_risk = optim.Adam(risk_model.parameters(), lr=args.risk_lr, eps=1e-5)
## Risk loss function
risk_criterion = nn.NLLLoss().to(device)

## Dimensions
action_dim = envs.single_action_space.shape[0]
obs_dim = envs.single_observation_space.shape[0]
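BayesRiskEst, ReplayBuffer and train_risk are defined in the repository's utils module and do not appear in this diff. Judging only from the call sites above (an input of risk_input_size features, an output that is concatenated to observations, and training with nn.NLLLoss), a stand-in consistent with this usage could look like the sketch below; the actual class in utils may be structured quite differently.

# --- editor's sketch, not part of the commit ---
# Hypothetical stand-in for utils.BayesRiskEst, inferred only from its call
# sites in this file: it maps a (state, action) vector of `input_size`
# features to log-probabilities over `risk_size` quantile bins, which is what
# nn.NLLLoss expects.
import torch.nn as nn

class RiskEstimatorSketch(nn.Module):
    def __init__(self, input_size: int, risk_size: int = 10, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, risk_size),
        )
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.log_softmax(self.net(x))
# --- end sketch ---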


# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
@@ -189,6 +162,8 @@ def get_action_and_value(self, x, action=None):
next_done = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size

actions_null = torch.zeros((1, action_dim)).to(device)
ep_trans = {"obs": [None]*args.num_envs, "actions": [None]*args.num_envs}
for update in range(1, num_updates + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
@@ -203,32 +178,65 @@ def get_action_and_value(self, x, action=None):

# ALGO LOGIC: action logic
with torch.no_grad():
action, logprob, _, value = agent.get_action_and_value(next_obs)
risk = risk_model(torch.cat([next_obs, actions_null], axis=-1)) if args.use_risk else None
action, logprob, _, value = agent.get_action_and_value(next_obs, risk)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
obs_ = next_obs
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminated, truncated)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
if args.use_risk and args.fine_tune_risk:
## Storing episodic data to dump later
for i in range(args.num_envs):
obs_i, next_obs_i, actions_i = obs_[i].unsqueeze(0), next_obs[i].unsqueeze(0), action[i].unsqueeze(0)
obs_i = torch.cat([obs_i, actions_i], axis=-1)
next_obs_i = torch.cat([next_obs_i, actions_null], axis=-1)
# print(obs_i.size(), next_obs_i.size(), actions_i.size())
ep_trans["obs"][i] = obs_i if ep_trans["obs"][i] is None else torch.cat([ep_trans["obs"][i], obs_i], axis=0)
# print(ep_trans["obs"][i].size(), obs_i.size())
ep_trans["obs"][i] = torch.cat([ep_trans["obs"][i], next_obs_i], axis=0)


## Train Risk model
if args.use_risk and args.fine_tune_risk:
risk_loss = 0
if (global_step >= args.start_risk_update) and (global_step % args.risk_update_period == 0):
for _ in range(args.num_risk_epochs):
risk_data = risk_rb.sample(args.num_risk_samples)
risk_loss += train_risk(args, risk_model, risk_data, risk_criterion, opt_risk, args.num_update_risk, device)
writer.add_scalar("risk/risk_loss", risk_loss / args.num_risk_epochs, global_step)


# Only print when at least 1 env is done
if "final_info" not in infos:
continue

for info in infos["final_info"]:
for i, info in enumerate(infos["final_info"]):
# Skip the envs that are not done
if info is None:
continue
ep_len = info["episode"]["l"]


if args.use_risk and args.fine_tune_risk:
ep_risks = np.array(list(reversed(range(int(ep_len))))) if info["cost"] > 0 else np.array([int(ep_len)]*int(ep_len))
ep_risks = torch.repeat_interleave(torch.Tensor(ep_risks), 2).to(device).unsqueeze(1)
risk_rb.add(ep_trans["obs"][i], ep_risks, ep_risks)
ep_trans["obs"][i] = None

print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
next_obs_risk = risk_model(torch.cat([next_obs, torch.zeros((next_obs.size()[0], action_dim)).to(device)], axis=-1)) if args.use_risk else None
next_value = agent.get_value(next_obs, next_obs_risk).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
@@ -250,6 +258,10 @@ def get_action_and_value(self, x, action=None):
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)

if args.use_risk:
with torch.no_grad():
b_risks = risk_model(torch.cat([b_obs, b_actions], axis=-1))

# Optimizing the policy and value network
b_inds = np.arange(args.batch_size)
clipfracs = []
@@ -259,7 +271,10 @@ def get_action_and_value(self, x, action=None):
end = start + args.minibatch_size
mb_inds = b_inds[start:end]

_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
if args.use_risk:
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
else:
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], None, b_actions[mb_inds])
logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()

