From 5eb1a1483d2b402d91951c8a4cde39e1bfd22624 Mon Sep 17 00:00:00 2001
From: taochenshh
Date: Wed, 23 Sep 2020 12:29:02 -0400
Subject: [PATCH] add loss to penalize the action mean drifting away from 1

---
 easyrl/agents/ppo_agent.py | 12 +++++++-----
 examples/ppo.py            |  4 ++--
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/easyrl/agents/ppo_agent.py b/easyrl/agents/ppo_agent.py
index 86fa7da..26f501c 100644
--- a/easyrl/agents/ppo_agent.py
+++ b/easyrl/agents/ppo_agent.py
@@ -17,6 +17,7 @@
 from easyrl.utils.torch_util import load_ckpt_data
 from easyrl.utils.torch_util import load_state_dict
 from easyrl.utils.torch_util import move_to
+from torch.distributions import Independent
 from easyrl.utils.torch_util import save_model
 from easyrl.utils.torch_util import torch_float
 from easyrl.utils.torch_util import torch_to_np
@@ -163,7 +164,7 @@ def optim_preprocess(self, data):
             raise ValueError('val, entropy, log_prob should be 1-dim!')
         return val, old_val, ret, log_prob, old_log_prob, adv, entropy
 
-    def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv, entropy):
+    def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv, entropy, act_dist):
         vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
         ratio = torch.exp(log_prob - old_log_prob)
         surr1 = adv * ratio
@@ -171,13 +172,14 @@
                                   1 - ppo_cfg.clip_range,
                                   1 + ppo_cfg.clip_range)
         pg_loss = -torch.mean(torch.min(surr1, surr2))
 
-        # if entropy.item() < 0.2:
-        #     ent_coef = 1
-        # else:
-        #     ent_coef = ppo_cfg.ent_coef
+        ent_coef = ppo_cfg.ent_coef
         loss = pg_loss - entropy * ent_coef + \
             vf_loss * ppo_cfg.vf_coef
+        if isinstance(act_dist, Independent):
+            dist = torch.abs(act_dist.mean) - 1.5
+            act_penalty = torch.mean(torch.max(dist, torch.zeros_like(dist)))
+            loss = loss + act_penalty
         return loss, pg_loss, vf_loss, ratio
 
     def cal_val_loss(self, val, old_val, ret):
diff --git a/examples/ppo.py b/examples/ppo.py
index ff607ff..1b81fdb 100644
--- a/examples/ppo.py
+++ b/examples/ppo.py
@@ -12,7 +12,7 @@
 from easyrl.runner.episodic_runner import EpisodicRunner
 from easyrl.utils.common import set_random_seed
 from easyrl.utils.gym_util import make_vec_env
-
+from pybullet_envs.gym_locomotion_envs import AntBulletEnv
 
 def main():
     cfg_from_cmd(ppo_cfg)
@@ -27,7 +27,7 @@ def main():
     skip_params = []
     ppo_cfg.restore_cfg(skip_params=skip_params)
     if ppo_cfg.env_name is None:
-        ppo_cfg.env_name = 'HalfCheetah-v2'
+        ppo_cfg.env_name = 'Ant-v2'
     set_random_seed(ppo_cfg.seed)
     env = make_vec_env(ppo_cfg.env_name,
                        ppo_cfg.num_envs,