Skip to content

Commit

Permalink
Add (disabled) learning rate schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed May 25, 2019
1 parent 992ba6a commit 2217d6f
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
# NOTE(review): rendered diff hunk from commit 2217d6f — NOT a runnable snippet.
# It contains both the pre-commit (removed) and post-commit (added) versions of
# the '--learning-rate' line; in the real file only one of them exists.
parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg', help='Global KL weight (0 to disable)')
parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats')
parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)')
# REMOVED by this commit (old line; its trailing note was split out to a standalone comment below):
parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate') # Note that original has a linear learning rate decay, but it seems unlikely that this makes a significant difference
# ADDED by this commit: same flag without the trailing note, plus the new
# (disabled-by-default) linear LR schedule and an explicit Adam epsilon flag.
parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate')
parser.add_argument('--learning-rate-schedule', type=int, default=0, metavar='αS', help='Linear learning rate schedule (optimisation steps from 0 to final learning rate; 0 to disable)')
parser.add_argument('--adam-epsilon', type=float, default=1e-4, metavar='ε', help='Adam optimiser epsilon value')
# Note that original has a linear learning rate decay, but it seems unlikely that this makes a significant difference
parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm')
parser.add_argument('--planning-horizon', type=int, default=12, metavar='H', help='Planning horizon distance')
parser.add_argument('--optimisation-iters', type=int, default=10, metavar='I', help='Planning optimisation iterations')
Expand Down Expand Up @@ -101,7 +104,7 @@
# NOTE(review): rendered diff hunk — both the removed and added 'optimiser = ...'
# lines appear below; only the second exists post-commit.
# Build the model components and collect all their parameters for one optimiser.
reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device)
encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size, args.activation_function).to(device=args.device)
param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters())
# REMOVED: hard-coded eps=1e-4 and fixed learning rate.
optimiser = optim.Adam(param_list, lr=args.learning_rate, eps=1e-4)
# ADDED: epsilon now comes from --adam-epsilon; when a schedule is enabled the
# LR starts at 0 and is ramped up each optimisation step (see the training loop).
optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon)
if args.load_checkpoint > 0:
  # Resume model weights from a saved checkpoint (models_<step>.pth in results_dir).
  model_dicts = torch.load(os.path.join(results_dir, 'models_%d.pth' % args.load_checkpoint))
  transition_model.load_state_dict(model_dicts['transition_model'])
Expand Down Expand Up @@ -160,6 +163,10 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
# NOTE(review): fragment from inside the training loop; scrape has stripped the
# original indentation, so nesting must be read from the diff context.
if args.overshooting_reward_scale != 0:
reward_loss += (1 / args.overshooting_distance) * args.overshooting_reward_scale * F.mse_loss(bottle(reward_model, (beliefs, prior_states)) * seq_mask[:, :, 0], torch.cat(overshooting_vars[2], dim=1), reduction='none').mean(dim=(0, 1)) * (args.chunk_size - 1) # Update reward loss (compensating for extra average over each overshooting/open loop sequence)

# Apply linearly ramping learning rate schedule
# ADDED by this commit: each optimisation step raises the LR by
# learning_rate / learning_rate_schedule, capped at the final learning_rate —
# i.e. a linear warm-up over 'learning_rate_schedule' steps (the optimiser was
# constructed with lr=0 when the schedule is enabled).
if args.learning_rate_schedule != 0:
for group in optimiser.param_groups:
group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate)
# Update model parameters
optimiser.zero_grad()
(observation_loss + reward_loss + kl_loss).backward()
Expand Down Expand Up @@ -219,7 +226,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
# NOTE(review): rendered diff hunk from the evaluation rollout; indentation
# stripped by the scrape. NOTE(review): test=False here despite this being the
# test-episode loop looks suspicious — confirm against the full file whether
# exploration noise is intentionally kept on during evaluation.
belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False)
total_rewards += reward.numpy()
if not args.symbolic_env: # Collect real vs. predicted frames for video
# REMOVED: decentring (+ 0.5) applied AFTER make_grid, i.e. also shifting the grid padding.
video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3), nrow=5).numpy() + 0.5) # Decentre
# ADDED: decentring applied to the frames BEFORE make_grid instead.
video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre
observation = next_observation
if done.sum().item() == args.test_episodes:
pbar.close()
Expand Down

0 comments on commit 2217d6f

Please sign in to comment.