From 77d9e9cc68d8430b3804c0772693ac79c3b74eda Mon Sep 17 00:00:00 2001
From: Kaixhin
Date: Fri, 14 Jun 2019 19:28:38 +0100
Subject: [PATCH] Add hyperparameters for data-efficient Rainbow

---
 README.md | 15 +++++++++++++++
 agent.py  |  2 +-
 main.py   | 10 ++++++----
 model.py  | 22 +++++++++++++---------
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 8158b24..5574006 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,20 @@ Results and pretrained models can be found in the [releases](https://github.com/
 - [x] Distributional RL [[7]](#references)
 - [x] Noisy Nets [[8]](#references)
 
+Data-efficient Rainbow [[9]](#references) can be run using the following options:
+
+```
+python main.py --target-update 2000 \
+               --T-max 100000 \
+               --learn-start 1600 \
+               --memory-capacity 100000 \
+               --replay-frequency 1 \
+               --multi-step 20 \
+               --architecture data-efficient \
+               --hidden-size 256 \
+               --learning-rate 0.0001
+```
+
 Requirements
 ------------
 
@@ -46,3 +60,4 @@ References
 [6] [Reinforcement Learning: An Introduction](http://www.incompleteideas.net/sutton/book/ebook/the-book.html)
 [7] [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887)
 [8] [Noisy Networks for Exploration](https://arxiv.org/abs/1706.10295)
+[9] [When to Use Parametric Models in Reinforcement Learning?](https://arxiv.org/abs/1906.05243)
diff --git a/agent.py b/agent.py
index 4c095d4..ea8bb96 100644
--- a/agent.py
+++ b/agent.py
@@ -30,7 +30,7 @@ def __init__(self, args, env):
     for param in self.target_net.parameters():
       param.requires_grad = False
 
-    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)
+    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps)
 
   # Resets noisy weights in all linear layers (of online net only)
   def reset_noise(self):
diff --git a/main.py b/main.py
index 0a3804d..82b1d7b 100644
--- a/main.py
+++ b/main.py
@@ -11,13 +11,15 @@
 from tqdm import tqdm
 
 
+# Note that hyperparameters may originally be reported in ATARI game frames instead of agent steps
 parser = argparse.ArgumentParser(description='Rainbow')
 parser.add_argument('--seed', type=int, default=123, help='Random seed')
 parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
 parser.add_argument('--game', type=str, default='space_invaders', choices=atari_py.list_games(), help='ATARI game')
 parser.add_argument('--T-max', type=int, default=int(50e6), metavar='STEPS', help='Number of training steps (4x number of frames)')
-parser.add_argument('--max-episode-length', type=int, default=int(108e3), metavar='LENGTH', help='Max episode length (0 to disable)')
+parser.add_argument('--max-episode-length', type=int, default=int(108e3), metavar='LENGTH', help='Max episode length in game frames (0 to disable)')
 parser.add_argument('--history-length', type=int, default=4, metavar='T', help='Number of consecutive states processed')
+parser.add_argument('--architecture', type=str, default='canonical', choices=['canonical', 'data-efficient'], metavar='ARCH', help='Network architecture')
 parser.add_argument('--hidden-size', type=int, default=512, metavar='SIZE', help='Network hidden size')
 parser.add_argument('--noisy-std', type=float, default=0.1, metavar='σ', help='Initial standard deviation of noisy linear layers')
 parser.add_argument('--atoms', type=int, default=51, metavar='C', help='Discretised size of value distribution')
@@ -30,12 +32,12 @@
 parser.add_argument('--priority-weight', type=float, default=0.4, metavar='β', help='Initial prioritised experience replay importance sampling weight')
 parser.add_argument('--multi-step', type=int, default=3, metavar='n', help='Number of steps for multi-step return')
 parser.add_argument('--discount', type=float, default=0.99, metavar='γ', help='Discount factor')
-parser.add_argument('--target-update', type=int, default=int(32e3), metavar='τ', help='Number of steps after which to update target network')
+parser.add_argument('--target-update', type=int, default=int(8e3), metavar='τ', help='Number of steps after which to update target network')
 parser.add_argument('--reward-clip', type=int, default=1, metavar='VALUE', help='Reward clipping (0 to disable)')
-parser.add_argument('--lr', type=float, default=0.0000625, metavar='η', help='Learning rate')
+parser.add_argument('--learning-rate', type=float, default=0.0000625, metavar='η', help='Learning rate')
 parser.add_argument('--adam-eps', type=float, default=1.5e-4, metavar='ε', help='Adam epsilon')
 parser.add_argument('--batch-size', type=int, default=32, metavar='SIZE', help='Batch size')
-parser.add_argument('--learn-start', type=int, default=int(80e3), metavar='STEPS', help='Number of steps before starting training')
+parser.add_argument('--learn-start', type=int, default=int(20e3), metavar='STEPS', help='Number of steps before starting training')
 parser.add_argument('--evaluate', action='store_true', help='Evaluate only')
 parser.add_argument('--evaluation-interval', type=int, default=100000, metavar='STEPS', help='Number of training steps between evaluations')
 parser.add_argument('--evaluation-episodes', type=int, default=10, metavar='N', help='Number of evaluation episodes to average over')
diff --git a/model.py b/model.py
index 0e542e4..bb01885 100644
--- a/model.py
+++ b/model.py
@@ -50,19 +50,23 @@ def __init__(self, args, action_space):
     self.atoms = args.atoms
     self.action_space = action_space
 
-    self.conv1 = nn.Conv2d(args.history_length, 32, 8, stride=4, padding=1)
-    self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
-    self.conv3 = nn.Conv2d(64, 64, 3)
-    self.fc_h_v = NoisyLinear(3136, args.hidden_size, std_init=args.noisy_std)
-    self.fc_h_a = NoisyLinear(3136, args.hidden_size, std_init=args.noisy_std)
+    if args.architecture == 'canonical':
+      self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 8, stride=4, padding=1), nn.ReLU(),
+                                 nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
+                                 nn.Conv2d(64, 64, 3, stride=1), nn.ReLU())
+      self.conv_output_size = 3136
+    elif args.architecture == 'data-efficient':
+      self.convs = nn.Sequential(nn.Conv2d(args.history_length, 32, 5, stride=5, padding=1), nn.ReLU(),
+                                 nn.Conv2d(32, 64, 5, stride=5), nn.ReLU())
+      self.conv_output_size = 576
+    self.fc_h_v = NoisyLinear(self.conv_output_size, args.hidden_size, std_init=args.noisy_std)
+    self.fc_h_a = NoisyLinear(self.conv_output_size, args.hidden_size, std_init=args.noisy_std)
     self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std)
     self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std)
 
   def forward(self, x, log=False):
-    x = F.relu(self.conv1(x))
-    x = F.relu(self.conv2(x))
-    x = F.relu(self.conv3(x))
-    x = x.view(-1, 3136)
+    x = self.convs(x)
+    x = x.view(-1, self.conv_output_size)
     v = self.fc_z_v(F.relu(self.fc_h_v(x)))  # Value stream
     a = self.fc_z_a(F.relu(self.fc_h_a(x)))  # Advantage stream
     v, a = v.view(-1, 1, self.atoms), a.view(-1, self.action_space, self.atoms)
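
As a sanity check on the hard-coded `conv_output_size` values in the model.py hunk above, the following standalone sketch (not part of the patch) builds both convolutional stacks and confirms the flattened feature sizes, assuming the standard 84x84 preprocessed ATARI frames and the default `--history-length 4` used by this repo:

```
import torch
import torch.nn as nn

history_length = 4  # number of stacked frames, matching the --history-length default

# Canonical Rainbow convolutional stack, as added in the patch
canonical = nn.Sequential(nn.Conv2d(history_length, 32, 8, stride=4, padding=1), nn.ReLU(),
                          nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
                          nn.Conv2d(64, 64, 3, stride=1), nn.ReLU())

# Data-efficient Rainbow convolutional stack, as added in the patch
data_efficient = nn.Sequential(nn.Conv2d(history_length, 32, 5, stride=5, padding=1), nn.ReLU(),
                               nn.Conv2d(32, 64, 5, stride=5), nn.ReLU())

x = torch.zeros(1, history_length, 84, 84)  # one batch of 84x84 stacked frames
print(canonical(x).view(1, -1).size(1))       # 3136
print(data_efficient(x).view(1, -1).size(1))  # 576
```

This is why the patch switches `fc_h_v`/`fc_h_a` from a hard-coded input size of 3136 to `self.conv_output_size`: the data-efficient architecture feeds the noisy hidden layers a 576-feature embedding instead, which pairs with the smaller `--hidden-size 256` in the README command.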