From 26a96780a22d88a3dd8ce8e99a4148fdcb5754b5 Mon Sep 17 00:00:00 2001
From: beduffy
Date: Mon, 18 Mar 2019 17:00:03 +0000
Subject: [PATCH] Renamed args.tau to args.gae_lambda and fixed typo

---
 main.py  | 4 ++--
 train.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index 369fdf9..bbb8114 100644
--- a/main.py
+++ b/main.py
@@ -20,8 +20,8 @@
                     help='learning rate (default: 0.0001)')
 parser.add_argument('--gamma', type=float, default=0.99,
                     help='discount factor for rewards (default: 0.99)')
-parser.add_argument('--tau', type=float, default=1.00,
-                    help='parameter for GAE (default: 1.00)')
+parser.add_argument('--gae-lambda', type=float, default=1.00,
+                    help='lambda parameter for GAE (default: 1.00)')
 parser.add_argument('--entropy-coef', type=float, default=0.01,
                     help='entropy term coefficient (default: 0.01)')
 parser.add_argument('--value-loss-coef', type=float, default=0.5,
diff --git a/train.py b/train.py
index 5184b91..da31b5e 100644
--- a/train.py
+++ b/train.py
@@ -92,10 +92,10 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
             advantage = R - values[i]
             value_loss = value_loss + 0.5 * advantage.pow(2)
 
-            # Generalized Advantage Estimataion
+            # Generalized Advantage Estimation
             delta_t = rewards[i] + args.gamma * \
                 values[i + 1] - values[i]
-            gae = gae * args.gamma * args.tau + delta_t
+            gae = gae * args.gamma * args.gae_lambda + delta_t
 
             policy_loss = policy_loss - \
                 log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]
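
Reviewer note (not part of the patch): for anyone checking the train.py hunk, below is a minimal standalone sketch of the same backward GAE recursion using plain Python floats instead of the repo's torch tensors. The gamma and gae_lambda values mirror the command-line defaults above; the rewards and values lists are made-up illustration data, not anything from the repo.

    # Minimal sketch of the GAE recursion that the train.py hunk implements,
    # with plain floats in place of torch tensors.
    gamma, gae_lambda = 0.99, 1.00     # mirror the --gamma / --gae-lambda defaults
    rewards = [0.0, 0.0, 1.0]          # hypothetical per-step rewards
    values = [0.1, 0.4, 0.7, 0.0]      # V(s_0)..V(s_T); the extra last entry bootstraps

    gae = 0.0
    advantages = []
    for i in reversed(range(len(rewards))):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta_t = rewards[i] + gamma * values[i + 1] - values[i]
        # GAE accumulator: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = gae * gamma * gae_lambda + delta_t
        advantages.insert(0, gae)

    print(advantages)

With the default gae_lambda of 1.00 the recursion reduces to the full discounted sum of TD residuals (the Monte Carlo advantage), while 0.0 gives the one-step TD residual; lambda is the standard name for this knob in the GAE paper, which is what makes the rename from the ambiguous tau worthwhile.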