Merge pull request taochenshh#2 from taochenshh/sac
Sac
taochenshh authored Apr 4, 2021
2 parents 49ee848 + a0add65 commit 6af1a1a
Showing 46 changed files with 2,454 additions and 537 deletions.
16 changes: 16 additions & 0 deletions easyrl/agents/base_agent.py
@@ -1,7 +1,23 @@
from dataclasses import dataclass
import gym
from easyrl.envs.vec_normalize import VecNormalize
from easyrl.utils.gym_util import save_vec_normalized_env
from easyrl.utils.gym_util import load_vec_normalized_env

@dataclass
class BaseAgent:
env: gym.Env

def get_action(self, ob, sample=True, **kwargs):
raise NotImplementedError

def optimize(self, data, **kwargs):
raise NotImplementedError

def save_env(self, save_dir):
if isinstance(self.env, VecNormalize):
save_vec_normalized_env(self.env, save_dir)

def load_env(self, save_dir):
if isinstance(self.env, VecNormalize):
load_vec_normalized_env(self.env, save_dir)
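
Aside (illustrative, not part of this commit): a minimal sketch of how a subclass is expected to plug into the new BaseAgent interface. The RandomAgent class below is hypothetical; the (action, info) return convention is borrowed from PPOAgent further down.

from dataclasses import dataclass

from easyrl.agents.base_agent import BaseAgent


@dataclass
class RandomAgent(BaseAgent):
    """Hypothetical example agent, shown only to illustrate the hooks."""

    def get_action(self, ob, sample=True, **kwargs):
        # Subclasses supply action selection; return (action, info) like PPOAgent.
        return self.env.action_space.sample(), {}

    def optimize(self, data, **kwargs):
        # A random policy learns nothing; real agents update their networks here.
        return {}

save_env and load_env are inherited as-is: they persist VecNormalize statistics when self.env is a VecNormalize wrapper and do nothing for a plain gym.Env.
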
196 changes: 81 additions & 115 deletions easyrl/agents/ppo_agent.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path

import numpy as np
import torch
@@ -8,63 +8,70 @@
from torch.optim.lr_scheduler import LambdaLR

from easyrl.agents.base_agent import BaseAgent
from easyrl.configs.ppo_config import ppo_cfg
from easyrl.configs import cfg
from easyrl.utils.common import linear_decay_percent
from easyrl.utils.rl_logger import logger
from easyrl.utils.torch_util import action_entropy
from easyrl.utils.torch_util import action_from_dist
from easyrl.utils.torch_util import action_log_prob
from easyrl.utils.torch_util import get_latest_ckpt
from easyrl.utils.torch_util import clip_grad
from easyrl.utils.torch_util import load_ckpt_data
from easyrl.utils.torch_util import load_state_dict
from easyrl.utils.torch_util import load_torch_model
from easyrl.utils.torch_util import move_to
from easyrl.utils.torch_util import save_model
from easyrl.utils.torch_util import torch_float
from easyrl.utils.torch_util import torch_to_np


@dataclass
class PPOAgent(BaseAgent):
def __init__(self, actor, critic, same_body=False):
self.actor = actor
self.critic = critic
self.actor.to(ppo_cfg.device)
self.critic.to(ppo_cfg.device)
self.same_body = same_body
if ppo_cfg.vf_loss_type == 'mse':
self.val_loss_criterion = nn.MSELoss().to(ppo_cfg.device)
elif ppo_cfg.vf_loss_type == 'smoothl1':
self.val_loss_criterion = nn.SmoothL1Loss().to(ppo_cfg.device)
actor: nn.Module
critic: nn.Module
same_body: float = False

def __post_init__(self):
move_to([self.actor, self.critic],
device=cfg.alg.device)
if cfg.alg.vf_loss_type == 'mse':
self.val_loss_criterion = nn.MSELoss().to(cfg.alg.device)
elif cfg.alg.vf_loss_type == 'smoothl1':
self.val_loss_criterion = nn.SmoothL1Loss().to(cfg.alg.device)
else:
raise TypeError(f'Unknown value loss type: {ppo_cfg.vf_loss_type}!')
raise TypeError(f'Unknown value loss type: {cfg.alg.vf_loss_type}!')
all_params = list(self.actor.parameters()) + list(self.critic.parameters())
# keep unique elements only. The following code works for python >=3.7
# for earlier version of python, u need to use OrderedDict
self.all_params = dict.fromkeys(all_params).keys()
if ppo_cfg.max_steps > ppo_cfg.max_decay_steps:
raise ValueError('max_steps should be no greater than max_decay_steps.')
total_epochs = int(np.ceil(ppo_cfg.max_decay_steps / (ppo_cfg.num_envs *
ppo_cfg.episode_steps)))
if ppo_cfg.linear_decay_clip_range:
self.clip_range_decay_rate = ppo_cfg.clip_range / float(total_epochs)
if (cfg.alg.linear_decay_lr or cfg.alg.linear_decay_clip_range) and \
cfg.alg.max_steps > cfg.alg.max_decay_steps:
logger.warning('max_steps should not be greater than max_decay_steps.')
cfg.alg.max_decay_steps = int(cfg.alg.max_steps * 1.5)
logger.warning(f'Resetting max_decay_steps to {cfg.alg.max_decay_steps}!')
total_epochs = int(np.ceil(cfg.alg.max_decay_steps / (cfg.alg.num_envs *
cfg.alg.episode_steps)))
if cfg.alg.linear_decay_clip_range:
self.clip_range_decay_rate = cfg.alg.clip_range / float(total_epochs)

p_lr_lambda = partial(linear_decay_percent,
total_epochs=total_epochs)
optim_args = dict(
lr=ppo_cfg.policy_lr,
weight_decay=ppo_cfg.weight_decay
lr=cfg.alg.policy_lr,
weight_decay=cfg.alg.weight_decay
)
if not ppo_cfg.sgd:
optim_args['amsgrad'] = ppo_cfg.use_amsgrad
if not cfg.alg.sgd:
optim_args['amsgrad'] = cfg.alg.use_amsgrad
optim_func = optim.Adam
else:
optim_args['nesterov'] = True if ppo_cfg.momentum > 0 else False
optim_args['momentum'] = ppo_cfg.momentum
optim_args['nesterov'] = True if cfg.alg.momentum > 0 else False
optim_args['momentum'] = cfg.alg.momentum
optim_func = optim.SGD
if self.same_body:
optim_args['params'] = self.all_params
else:
optim_args['params'] = [{'params': self.actor.parameters(),
'lr': ppo_cfg.policy_lr},
'lr': cfg.alg.policy_lr},
{'params': self.critic.parameters(),
'lr': ppo_cfg.value_lr}]
'lr': cfg.alg.value_lr}]

self.optimizer = optim_func(**optim_args)

@@ -76,12 +83,11 @@ def __init__(self, actor, critic, same_body=False):
total_epochs=total_epochs)
self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
lr_lambda=[p_lr_lambda, v_lr_lambda])
self.in_training = False

@torch.no_grad()
def get_action(self, ob, sample=True, *args, **kwargs):
self.eval_mode()
t_ob = torch.from_numpy(ob).float().to(ppo_cfg.device)
t_ob = torch_float(ob, device=cfg.alg.device)
act_dist, val = self.get_act_val(t_ob)
action = action_from_dist(act_dist,
sample=sample)
@@ -95,7 +101,7 @@ def get_action(self, ob, sample=True, *args, **kwargs):
return torch_to_np(action), action_info

def get_act_val(self, ob, *args, **kwargs):
ob = torch_float(ob, device=ppo_cfg.device)
ob = torch_float(ob, device=cfg.alg.device)
act_dist, body_out = self.actor(ob)
if self.same_body:
val, body_out = self.critic(body_x=body_out)
@@ -107,49 +113,41 @@ def get_act_val(self, ob, *args, **kwargs):
@torch.no_grad()
def get_val(self, ob, *args, **kwargs):
self.eval_mode()
ob = torch_float(ob, device=ppo_cfg.device)
ob = torch_float(ob, device=cfg.alg.device)
val, body_out = self.critic(x=ob)
val = val.squeeze(-1)
return val

def optimize(self, data, *args, **kwargs):
self.train_mode()
pre_res = self.optim_preprocess(data)
val, old_val, ret, log_prob, old_log_prob, adv, entropy = pre_res
entropy = torch.mean(entropy)
loss_res = self.cal_loss(val=val,
old_val=old_val,
ret=ret,
log_prob=log_prob,
old_log_prob=old_log_prob,
adv=adv,
entropy=entropy)
processed_data = pre_res
processed_data['entropy'] = torch.mean(processed_data['entropy'])
loss_res = self.cal_loss(**processed_data)
loss, pg_loss, vf_loss, ratio = loss_res
self.optimizer.zero_grad()
loss.backward()
grad_norm = None
if ppo_cfg.max_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(self.all_params,
ppo_cfg.max_grad_norm)

grad_norm = clip_grad(self.all_params, cfg.alg.max_grad_norm)
self.optimizer.step()
with torch.no_grad():
approx_kl = 0.5 * torch.mean(torch.pow(old_log_prob - log_prob, 2))
clip_frac = np.mean(np.abs(torch_to_np(ratio) - 1.0) > ppo_cfg.clip_range)
approx_kl = 0.5 * torch.mean(torch.pow(processed_data['old_log_prob'] -
processed_data['log_prob'], 2))
clip_frac = np.mean(np.abs(torch_to_np(ratio) - 1.0) > cfg.alg.clip_range)
optim_info = dict(
pg_loss=pg_loss.item(),
vf_loss=vf_loss.item(),
total_loss=loss.item(),
entropy=entropy.item(),
entropy=processed_data['entropy'].item(),
approx_kl=approx_kl.item(),
clip_frac=clip_frac
)
if grad_norm is not None:
optim_info['grad_norm'] = grad_norm
optim_info['grad_norm'] = grad_norm
return optim_info

def optim_preprocess(self, data):
self.train_mode()
for key, val in data.items():
data[key] = torch_float(val, device=ppo_cfg.device)
data[key] = torch_float(val, device=cfg.alg.device)
ob = data['ob']
action = data['action']
ret = data['ret']
@@ -162,26 +160,35 @@ def optim_preprocess(self, data):
entropy = action_entropy(act_dist, log_prob)
if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
raise ValueError('val, entropy, log_prob should be 1-dim!')
return val, old_val, ret, log_prob, old_log_prob, adv, entropy
processed_data = dict(
val=val,
old_val=old_val,
ret=ret,
log_prob=log_prob,
old_log_prob=old_log_prob,
adv=adv,
entropy=entropy
)
return processed_data

def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv, entropy):
vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
ratio = torch.exp(log_prob - old_log_prob)
surr1 = adv * ratio
surr2 = adv * torch.clamp(ratio,
1 - ppo_cfg.clip_range,
1 + ppo_cfg.clip_range)
1 - cfg.alg.clip_range,
1 + cfg.alg.clip_range)
pg_loss = -torch.mean(torch.min(surr1, surr2))

loss = pg_loss - entropy * ppo_cfg.ent_coef + \
vf_loss * ppo_cfg.vf_coef
loss = pg_loss - entropy * cfg.alg.ent_coef + \
vf_loss * cfg.alg.vf_coef
return loss, pg_loss, vf_loss, ratio

def cal_val_loss(self, val, old_val, ret):
if ppo_cfg.clip_vf_loss:
if cfg.alg.clip_vf_loss:
clipped_val = old_val + torch.clamp(val - old_val,
-ppo_cfg.clip_range,
ppo_cfg.clip_range)
-cfg.alg.clip_range,
cfg.alg.clip_range)
vf_loss1 = torch.pow(val - ret, 2)
vf_loss2 = torch.pow(clipped_val - ret, 2)
vf_loss = 0.5 * torch.mean(torch.max(vf_loss1,
@@ -192,48 +199,28 @@ def cal_val_loss(self, val, old_val, ret):
return vf_loss

def train_mode(self):
self.in_training = True
self.actor.train()
self.critic.train()

def eval_mode(self):
self.in_training = False
self.actor.eval()
self.critic.eval()

def decay_lr(self):
self.lr_scheduler.step()

def get_lr(self):
try:
cur_lr = self.lr_scheduler.get_last_lr()
except AttributeError:
cur_lr = self.lr_scheduler.get_lr()
cur_lr = self.lr_scheduler.get_lr()
lrs = {'policy_lr': cur_lr[0]}
if len(cur_lr) > 1:
lrs['value_lr'] = cur_lr[1]
return lrs

def decay_clip_range(self):
ppo_cfg.clip_range -= self.clip_range_decay_rate
cfg.alg.clip_range -= self.clip_range_decay_rate

def save_model(self, is_best=False, step=None):
if not ppo_cfg.save_best_only and step is not None:
ckpt_file = ppo_cfg.model_dir \
.joinpath('ckpt_{:012d}.pt'.format(step))
else:
ckpt_file = None
if is_best:
best_model_file = ppo_cfg.model_dir \
.joinpath('model_best.pt')
else:
best_model_file = None

if not ppo_cfg.save_best_only:
saved_model_files = sorted(ppo_cfg.model_dir.glob('*.pt'))
if len(saved_model_files) > ppo_cfg.max_saved_models:
saved_model_files[0].unlink()

self.save_env(cfg.alg.model_dir)
data_to_save = {
'step': step,
'actor_state_dict': self.actor.state_dict(),
@@ -242,47 +229,26 @@ def save_model(self, is_best=False, step=None):
'lr_scheduler_state_dict': self.lr_scheduler.state_dict()
}

if ppo_cfg.linear_decay_clip_range:
data_to_save['clip_range'] = ppo_cfg.clip_range
if cfg.alg.linear_decay_clip_range:
data_to_save['clip_range'] = cfg.alg.clip_range
data_to_save['clip_range_decay_rate'] = self.clip_range_decay_rate
logger.info(f'Exploration steps: {step}')
for fl in [ckpt_file, best_model_file]:
if fl is not None:
logger.info(f'Saving checkpoint: {fl}.')
torch.save(data_to_save, fl)
save_model(data_to_save, cfg.alg, is_best=is_best, step=step)

def load_model(self, step=None, pretrain_model=None):
if pretrain_model is not None:
# if the pretrain_model is the path of the folder
# that contains the checkpoint files, then it will
# load the most recent one.
if isinstance(pretrain_model, str):
pretrain_model = Path(pretrain_model)
if pretrain_model.suffix != '.pt':
pretrain_model = get_latest_ckpt(pretrain_model)
ckpt_data = load_torch_model(pretrain_model)
load_state_dict(self.actor,
ckpt_data['actor_state_dict'])
load_state_dict(self.critic,
ckpt_data['critic_state_dict'])
return
if step is None:
ckpt_file = Path(ppo_cfg.model_dir) \
.joinpath('model_best.pt')
else:
ckpt_file = Path(ppo_cfg.model_dir) \
.joinpath('ckpt_{:012d}.pt'.format(step))

ckpt_data = load_torch_model(ckpt_file)
self.load_env(cfg.alg.model_dir)
ckpt_data = load_ckpt_data(cfg.alg, step=step,
pretrain_model=pretrain_model)
load_state_dict(self.actor,
ckpt_data['actor_state_dict'])
load_state_dict(self.critic,
ckpt_data['critic_state_dict'])
if pretrain_model is not None:
return
self.optimizer.load_state_dict(ckpt_data['optim_state_dict'])
self.lr_scheduler.load_state_dict(ckpt_data['lr_scheduler_state_dict'])
if ppo_cfg.linear_decay_clip_range:
if cfg.alg.linear_decay_clip_range:
self.clip_range_decay_rate = ckpt_data['clip_range_decay_rate']
ppo_cfg.clip_range = ckpt_data['clip_range']
cfg.alg.clip_range = ckpt_data['clip_range']
return ckpt_data['step']

def print_param_grad_status(self):
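
Aside (illustrative, not part of this commit): the comment in __post_init__ above notes that dict.fromkeys(...).keys() keeps only unique parameters while preserving insertion order (guaranteed on Python >= 3.7). A minimal standalone sketch of that trick, with hypothetical body/head modules, shows why it matters when the actor and critic share a body:

import torch.nn as nn
from torch import optim

# Hypothetical shared-body setup: the body's parameters appear in both lists.
body = nn.Linear(8, 8)
actor_head = nn.Linear(8, 2)
critic_head = nn.Linear(8, 1)

all_params = (list(body.parameters()) + list(actor_head.parameters()) +
              list(body.parameters()) + list(critic_head.parameters()))
unique_params = dict.fromkeys(all_params).keys()  # order-preserving dedup (Python >= 3.7)

print(len(all_params), len(unique_params))  # 8 entries collapse to 6 unique tensors

# Passing the de-duplicated view to the optimizer, as __post_init__ does,
# avoids registering the shared-body parameters twice.
optimizer = optim.Adam(unique_params, lr=3e-4)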