Skip to content

Commit

Permalink
add loss to penalize the action mean drifting away from 1
Browse files Browse the repository at this point in the history
  • Loading branch information
taochenshh committed Sep 23, 2020
1 parent b50fe4c commit 5eb1a14
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
12 changes: 7 additions & 5 deletions easyrl/agents/ppo_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from easyrl.utils.torch_util import load_ckpt_data
from easyrl.utils.torch_util import load_state_dict
from easyrl.utils.torch_util import move_to
from torch.distributions import Independent
from easyrl.utils.torch_util import save_model
from easyrl.utils.torch_util import torch_float
from easyrl.utils.torch_util import torch_to_np
Expand Down Expand Up @@ -163,21 +164,22 @@ def optim_preprocess(self, data):
raise ValueError('val, entropy, log_prob should be 1-dim!')
return val, old_val, ret, log_prob, old_log_prob, adv, entropy

def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv, entropy):
def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv, entropy, act_dist):
vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
ratio = torch.exp(log_prob - old_log_prob)
surr1 = adv * ratio
surr2 = adv * torch.clamp(ratio,
1 - ppo_cfg.clip_range,
1 + ppo_cfg.clip_range)
pg_loss = -torch.mean(torch.min(surr1, surr2))
# if entropy.item() < 0.2:
# ent_coef = 1
# else:
# ent_coef = ppo_cfg.ent_coef

ent_coef = ppo_cfg.ent_coef
loss = pg_loss - entropy * ent_coef + \
vf_loss * ppo_cfg.vf_coef
if isinstance(act_dist, Independent):
dist = torch.abs(act_dist.mean) - 1.5
act_penalty = torch.mean(torch.max(dist, torch.zeros_like(dist)))
loss = loss + act_penalty
return loss, pg_loss, vf_loss, ratio

def cal_val_loss(self, val, old_val, ret):
Expand Down
4 changes: 2 additions & 2 deletions examples/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from easyrl.runner.episodic_runner import EpisodicRunner
from easyrl.utils.common import set_random_seed
from easyrl.utils.gym_util import make_vec_env

from pybullet_envs.gym_locomotion_envs import AntBulletEnv

def main():
cfg_from_cmd(ppo_cfg)
Expand All @@ -27,7 +27,7 @@ def main():
skip_params = []
ppo_cfg.restore_cfg(skip_params=skip_params)
if ppo_cfg.env_name is None:
ppo_cfg.env_name = 'HalfCheetah-v2'
ppo_cfg.env_name = 'Ant-v2'
set_random_seed(ppo_cfg.seed)
env = make_vec_env(ppo_cfg.env_name,
ppo_cfg.num_envs,
Expand Down

0 comments on commit 5eb1a14

Please sign in to comment.