
Make results subdirs
Kaixhin committed May 22, 2019
1 parent a5b1a60 commit 76c17ce
Showing 2 changed files with 19 additions and 17 deletions.
.gitignore: 1 change (0 additions, 1 deletion)
@@ -118,4 +118,3 @@ dmypy.json

# Results
results/
-checkpoints/
main.py: 35 changes (19 additions, 16 deletions)
@@ -6,7 +6,7 @@
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch.nn import functional as F
-from torchvision.utils import save_image
+from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from env import CONTROL_SUITE_ENVS, Env, GYM_ENVS, EnvBatcher
from memory import ExperienceReplay
@@ -17,6 +17,7 @@

# Hyperparameters
parser = argparse.ArgumentParser(description='PlaNet')
+parser.add_argument('--id', type=str, default='default', help='Experiment ID')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed')
parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
parser.add_argument('--env', type=str, default='Pendulum-v0', choices=GYM_ENVS + CONTROL_SUITE_ENVS, help='Gym/Control Suite environment')
@@ -62,22 +63,22 @@


# Setup
+results_dir = os.path.join('results', args.id)
+os.makedirs(results_dir, exist_ok=True)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available() and not args.disable_cuda:
args.device = torch.device('cuda')
torch.cuda.manual_seed(args.seed)
else:
args.device = torch.device('cpu')
-os.makedirs('results', exist_ok=True)
-os.makedirs('checkpoints', exist_ok=True)
metrics = {'steps': [], 'episodes': [], 'train_rewards': [], 'test_episodes': [], 'test_rewards': [], 'observation_loss': [], 'reward_loss': [], 'kl_loss': []}


# Initialise training environment and experience replay memory
env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth)
if args.load_experience:
-D = torch.load(os.path.join('checkpoints', 'experience.pth'))
+D = torch.load(os.path.join(results_dir, 'experience.pth'))
metrics['steps'], metrics['episodes'] = [D.steps] * D.episodes, list(range(1, D.episodes + 1))
else:
D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.device)
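
For context, the pattern introduced in the hunk above is simple to state on its own: every output of a run now lives under results/<id> rather than shared results/ and checkpoints/ folders. A minimal sketch follows; the --id flag, its default value, and the results_dir lines are taken from the diff, while the final print is added only for illustration.

import argparse
import os

parser = argparse.ArgumentParser(description='PlaNet')
parser.add_argument('--id', type=str, default='default', help='Experiment ID')
args = parser.parse_args()

# All artefacts (plots, metrics, model checkpoints, experience replay)
# are written to one per-experiment directory.
results_dir = os.path.join('results', args.id)
os.makedirs(results_dir, exist_ok=True)
print('Writing outputs to', results_dir)

Running, for example, python main.py --id pendulum-1 therefore keeps that run's plots and checkpoints separate from other experiments.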
@@ -102,7 +103,7 @@
param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters())
optimiser = optim.Adam(param_list, lr=args.learning_rate, eps=1e-4)
if args.load_checkpoint > 0:
-model_dicts = torch.load(os.path.join('checkpoints', 'models_%d.pth' % args.load_checkpoint))
+model_dicts = torch.load(os.path.join(results_dir, 'models_%d.pth' % args.load_checkpoint))
transition_model.load_state_dict(model_dicts['transition_model'])
observation_model.load_state_dict(model_dicts['observation_model'])
reward_model.load_state_dict(model_dicts['reward_model'])
@@ -172,9 +173,9 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
metrics['observation_loss'].append(losses[0])
metrics['reward_loss'].append(losses[1])
metrics['kl_loss'].append(losses[2])
-lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', 'results')
-lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', 'results')
-lineplot(metrics['episodes'][-len(metrics['kl_loss']):], metrics['kl_loss'], 'kl_loss', 'results')
+lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', results_dir)
+lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', results_dir)
+lineplot(metrics['episodes'][-len(metrics['kl_loss']):], metrics['kl_loss'], 'kl_loss', results_dir)


# Data collection
@@ -197,7 +198,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
metrics['steps'].append(t + metrics['steps'][-1])
metrics['episodes'].append(episode)
metrics['train_rewards'].append(total_reward)
-lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', 'results')
+lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir)


# Test model
@@ -218,7 +219,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False)
total_rewards += reward.numpy()
if not args.symbolic_env: # Collect real vs. predicted frames for video
-video_frames.append(torch.cat(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3).split(1, dim=0), dim=2)[0].numpy())
+video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3), nrow=5).numpy())
observation = next_observation
if done.sum().item() == args.test_episodes:
pbar.close()
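
The added make_grid call above replaces the manual cat/split tiling with torchvision's built-in grid layout. A rough stand-alone sketch of what that line produces, using random tensors as stand-ins for the real and reconstructed observations (the batch size and 3x64x64 shape are assumptions for illustration, not taken from the repo):

import torch
from torchvision.utils import make_grid, save_image

# Stand-ins: a batch of 10 test-episode frames and their reconstructions,
# each 3x64x64 with values in [0, 1].
real = torch.rand(10, 3, 64, 64)
recon = torch.rand(10, 3, 64, 64)

# Pair each real frame with its reconstruction side by side along the width
# (dim=3), then tile the pairs five to a row into a single 3xHxW image.
grid = make_grid(torch.cat([real, recon], dim=3), nrow=5)
save_image(grid, 'real_vs_predicted.png')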
@@ -227,11 +228,13 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
# Update and plot reward metrics (and write video if applicable) and save metrics
metrics['test_episodes'].append(episode)
metrics['test_rewards'].append(total_rewards.tolist())
-lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', 'results')
-lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', 'results', xaxis='step')
+lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir)
+lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', results_dir, xaxis='step')
if not args.symbolic_env:
-write_video(video_frames, 'test_episode_%s' % str(episode).zfill(len(str(args.episodes))), 'results')
-torch.save(metrics, os.path.join('results', 'metrics.pth'))
+episode_str = str(episode).zfill(len(str(args.episodes)))
+write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
+save_image(torch.as_tensor(video_frames[-1]), os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
+torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

# Set models to train mode
transition_model.train()
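
Since metrics.pth is just a pickled dict (see the metrics initialisation earlier in the diff), it can be reloaded later for offline inspection. A minimal sketch, assuming only the default experiment ID and the torch.save call above:

import torch

# Load the curves saved by main.py under results/<id>/metrics.pth.
metrics = torch.load('results/default/metrics.pth')
print(metrics['test_episodes'])
print(metrics['test_rewards'])  # one list of per-episode rewards per test phase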
@@ -244,9 +247,9 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,

# Checkpoint models
if episode % args.checkpoint_interval == 0:
-torch.save({'transition_model': transition_model.state_dict(), 'observation_model': observation_model.state_dict(), 'reward_model': reward_model.state_dict(), 'encoder': encoder.state_dict(), 'optimiser': optimiser.state_dict()}, os.path.join('checkpoints', 'models_%d.pth' % episode))
+torch.save({'transition_model': transition_model.state_dict(), 'observation_model': observation_model.state_dict(), 'reward_model': reward_model.state_dict(), 'encoder': encoder.state_dict(), 'optimiser': optimiser.state_dict()}, os.path.join(results_dir, 'models_%d.pth' % episode))
if args.checkpoint_experience:
-torch.save(D, os.path.join('checkpoints', 'experience.pth'))  # Warning: will fail with MemoryError with large memory sizes
+torch.save(D, os.path.join(results_dir, 'experience.pth'))  # Warning: will fail with MemoryError with large memory sizes


# Close training environment
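
The checkpoint written above is a plain dict of state_dicts keyed by component name, so restoring it is symmetric (as the load_checkpoint branch earlier in the diff shows). A toy round-trip sketch with stand-in modules; nn.Linear is only a placeholder for the PlaNet networks:

import os
import torch
from torch import nn, optim

results_dir = os.path.join('results', 'default')
os.makedirs(results_dir, exist_ok=True)

model = nn.Linear(4, 4)  # placeholder network
optimiser = optim.Adam(model.parameters(), lr=1e-3)

# Save one dict of state_dicts per checkpoint, as main.py does.
torch.save({'model': model.state_dict(), 'optimiser': optimiser.state_dict()},
           os.path.join(results_dir, 'models_1.pth'))

# Restore by loading the dict and pushing each entry back into its object.
checkpoint = torch.load(os.path.join(results_dir, 'models_1.pth'))
model.load_state_dict(checkpoint['model'])
optimiser.load_state_dict(checkpoint['optimiser'])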

