
Commit

state action risk
manila95 committed Dec 18, 2023
1 parent cd0f682 commit a974cd6
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions cleanrl/ppo_continuous_action_wandb.py
@@ -9,6 +9,7 @@
import gymnasium as gym
import numpy as np
import torch
+import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
@@ -247,7 +248,9 @@ def get_action_and_value(self, x, risk, action=None):
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk), candidates

class RiskAgent1(nn.Module):
    def __init__(self, envs, linear_size=64, risk_size=2):
@@ -280,7 +283,9 @@ def get_action_and_value(self, x, risk, action=None):
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates


class Agent(nn.Module):
@@ -312,7 +317,8 @@ def get_action_and_value(self, x, action=None):
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates


class ContRiskAgent(nn.Module):
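
Note: all three policy classes (`RiskAgent`, `RiskAgent1`, `Agent`) now draw five candidate actions with `probs.sample_n(5)` and return them alongside the usual PPO outputs; the candidates are consumed in the rollout loop further down. A minimal sketch of the pattern, assuming the same diagonal `Normal` policy head (`Distribution.sample_n` is deprecated in recent PyTorch, where `probs.sample((5,))` is the equivalent call):

    import torch
    from torch.distributions import Normal

    # Hypothetical policy head output: batch of 1 env, action dim 2.
    action_mean = torch.zeros(1, 2)
    action_std = torch.ones(1, 2)
    probs = Normal(action_mean, action_std)

    action = probs.sample()         # the action PPO trains on, shape (1, 2)
    candidates = probs.sample_n(5)  # five extra draws, shape (5, 1, 2)
    # Equivalent in current PyTorch: candidates = probs.sample((5,))
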
@@ -412,28 +418,6 @@ def risk_sgd_step(cfg, model, data, criterion, opt, device):
    return loss


-def train_risk(cfg, model, data, criterion, opt, device):
-    model.train()
-    dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
-                           fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
-    dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
-    net_loss = 0
-    for batch in dataloader:
-        pred = model(get_risk_obs(cfg, batch[0]).to(device))
-        if cfg.model_type == "mlp":
-            loss = criterion(pred, batch[1].squeeze().to(device))
-        else:
-            loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
-        opt.zero_grad()
-        loss.backward()
-        opt.step()
-
-        net_loss += loss.item()
-    torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
-    wandb.save("risk_model.pt")
-    model.eval()
-    print("risk_loss:", net_loss)
-    return net_loss

def test_policy(cfg, agent, envs, device, risk_model=None):
    envs = gym.vector.SyncVectorEnv(
@@ -574,7 +558,7 @@ def train(cfg):
            agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
        #else:
        #    agent = ContRiskAgent(envs=envs).to(device)
-        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
+        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk")
        if os.path.exists(cfg.risk_model_path):
            risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
            print("Pretrained risk model loaded successfully")
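
Note: the risk network is now constructed as a state-action model (`action_size=...`, `model_type="state_action_risk"`). Its definition is not part of this diff; judging from the call sites below, it maps an (observation, action) pair to log-probabilities over risk classes. A hypothetical minimal version, consistent with those call sites and for orientation only:

    import torch
    import torch.nn as nn

    class StateActionRiskMLP(nn.Module):
        """Sketch only: the real class comes from risk_model_class in the repo."""
        def __init__(self, obs_size, action_size, out_size=2, batch_norm=True):
            super().__init__()
            layers = [nn.Linear(obs_size + action_size, 128)]
            if batch_norm:
                layers.append(nn.BatchNorm1d(128))
            layers += [nn.ReLU(), nn.Linear(128, out_size), nn.LogSoftmax(dim=-1)]
            self.net = nn.Sequential(*layers)

        def forward(self, obs, action):
            # Concatenate state and action; return log-probs so callers can
            # torch.exp(...) them, as the rollout loop below does.
            return self.net(torch.cat([obs, action], dim=-1))
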
@@ -702,14 +686,20 @@ def train(cfg):
            # ALGO LOGIC: action logic
            with torch.no_grad():
                if cfg.use_risk:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs, next_risk)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs, next_risk)
                else:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs)

                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

+            if cfg.use_risk:
+                with torch.no_grad():
+                    candidates = candidates.squeeze()
+                    # print(next_obs_risk.repeat(5, 1).size(), candidates.size())
+                    candidates_risk = torch.sum(torch.exp(risk_model(next_obs_risk.repeat(5, 1).to(device), candidates))[:, :2], -1)
+                    action = candidates[torch.argmin(candidates_risk)]
            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
            done = np.logical_or(terminated, truncated)
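
Note: this is the behavioral core of the commit. After PPO samples its action, the five candidates are scored by the risk model, and the executed action becomes the candidate with the least probability mass in the first two risk classes (the `[:, :2]` slice). Two caveats visible in the diff itself: the flat `repeat(5, 1)` and scalar `argmin` assume a single vectorized env, and `actions[step]`/`logprobs[step]` still record the originally sampled action, so the PPO update sees a slightly different action than the one executed. A sketch of the selection step under those assumptions:

    import torch

    def pick_safest(risk_model, next_obs_risk, candidates, device):
        # candidates: (5, 1, action_dim) from sample_n -> (5, action_dim)
        candidates = candidates.squeeze()
        obs_batch = next_obs_risk.repeat(5, 1).to(device)  # one obs copy per candidate
        log_probs = risk_model(obs_batch, candidates)      # (5, num_risk_classes)
        risk_scores = torch.exp(log_probs)[:, :2].sum(-1)  # mass in the riskiest bins
        return candidates[torch.argmin(risk_scores)]       # least risky candidate
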
@@ -738,12 +728,13 @@ def train(cfg):
            for i in range(cfg.num_envs):
                f_obs[i] = obs_[i].unsqueeze(0).to(device) if f_obs[i] is None else torch.concat([f_obs[i], obs_[i].unsqueeze(0).to(device)], axis=0)
                f_next_obs[i] = next_obs[i].unsqueeze(0).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], next_obs[i].unsqueeze(0).to(device)], axis=0)
-                f_actions[i] = action[i].unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action[i].unsqueeze(0).to(device)], axis=0)
+                f_actions[i] = action.unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action.unsqueeze(0).to(device)], axis=0)
                f_rewards[i] = reward[i].unsqueeze(0).to(device) if f_rewards[i] is None else torch.concat([f_rewards[i], reward[i].unsqueeze(0).to(device)], axis=0)
                # f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
                f_costs[i] = cost[i].unsqueeze(0).to(device) if f_costs[i] is None else torch.concat([f_costs[i], cost[i].unsqueeze(0).to(device)], axis=0)
                f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)

+            # print(f_actions[0].size())
            obs_ = next_obs
            # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
            # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
@@ -767,15 +758,24 @@
                    writer.add_scalar("risk/risk_loss", risk_loss, global_step)
            elif cfg.fine_tune_risk == "off" and cfg.use_risk:
                if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-                    for epoch in range(cfg.num_risk_epochs):
+                    for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
                        total_risk_updates +=1
                        print(total_risk_updates)
                        if cfg.finetune_risk_online:
                            print("I am online")
                            data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
                        else:
                            data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-                        risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
+                        state = torch.cat([data["obs"], data["next_obs"]], axis=0)
+                        actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0)
+                        dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0)
+                        print(state.size(), actions.size(), dist_to_fail.size())
+                        risk_dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device), True, risk_type=cfg.risk_type,
+                                                    fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+                        risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device=device))
+
+                        risk_loss = train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action")
+
                        writer.add_scalar("risk/risk_loss", risk_loss, global_step)

            # Only print when at least 1 env is done
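
Note: the "off" fine-tuning branch above no longer delegates dataset construction to the deleted `train_risk` helper; it builds state-action training pairs inline. Each sampled transition contributes two labelled examples, (obs, action) and (next_obs, zero-action), both tagged with the same distance-to-failure, so the model also learns a state-only risk estimate through the zero action. A condensed sketch of one update, with the `RiskyDataset` and new `train_risk` signatures taken from their call sites in this diff:

    import torch
    from torch.utils.data import DataLoader

    def finetune_risk_step(cfg, rb, risk_model, criterion, opt_risk, device):
        data = rb.sample(cfg.risk_batch_size * cfg.num_update_risk)
        # Pair each obs with its action, and each next_obs with a zero action,
        # duplicating the distance-to-failure labels for the second half.
        state = torch.cat([data["obs"], data["next_obs"]], axis=0)
        actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0)
        dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0)
        dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device),
                               True, risk_type=cfg.risk_type, fear_clip=None,
                               fear_radius=cfg.fear_radius, one_hot=True,
                               quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
        loader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True)
        return train_risk(risk_model, loader, criterion, opt_risk, 1, device,
                          train_mode="state_action")
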
@@ -853,7 +853,7 @@ def train(cfg):
                    if cfg.risk_type == "binary":
                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
                    else:
-                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+                        rb.add(get_risk_obs(cfg, f_obs[i]), get_risk_obs(cfg, f_next_obs[i]), f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)

                    f_obs[i] = None
                    f_next_obs[i] = None
@@ -915,9 +915,9 @@ def train(cfg):
                mb_inds = b_inds[start:end]

                if cfg.use_risk:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
                else:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])

                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()
