From 992ba6a889b004a62e8be7d112e8a030e55eab88 Mon Sep 17 00:00:00 2001 From: Kaixhin Date: Wed, 22 May 2019 18:49:17 +0100 Subject: [PATCH] Centre data --- env.py | 10 +++++++--- main.py | 4 ++-- memory.py | 12 +++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/env.py b/env.py index 6a16f7d..d54adf1 100644 --- a/env.py +++ b/env.py @@ -8,10 +8,14 @@ CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2} -def _images_to_observation(images, bit_depth): - images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32) # Resize - images.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth) # Quantise to given bit depth (note that original implementation also centers data) +def quantise_centre_dequantise(images, bit_depth): + images.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5) # Quantise to given bit depth and centre images.add_(torch.rand_like(images).div_(2 ** bit_depth)) # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images) + + +def _images_to_observation(images, bit_depth): + images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32) # Resize and put channel first + quantise_centre_dequantise(images, bit_depth) # Quantise, centre and dequantise inplace return images.unsqueeze(dim=0) # Add batch dimension diff --git a/main.py b/main.py index 231a9f1..f30b064 100644 --- a/main.py +++ b/main.py @@ -81,7 +81,7 @@ D = torch.load(os.path.join(results_dir, 'experience.pth')) metrics['steps'], metrics['episodes'] = [D.steps] * D.episodes, list(range(1, D.episodes + 1)) else: - D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.device) + D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device) # Initialise dataset D with S random seed episodes for s in range(1, args.seed_episodes + 1): observation, done, t = env.reset(), False, 0 @@ -219,7 +219,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief, belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False) total_rewards += reward.numpy() if not args.symbolic_env: # Collect real vs. predicted frames for video - video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3), nrow=5).numpy()) + video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3), nrow=5).numpy() + 0.5) # Decentre observation = next_observation if done.sum().item() == args.test_episodes: pbar.close() diff --git a/memory.py b/memory.py index afd109d..7ee35af 100644 --- a/memory.py +++ b/memory.py @@ -1,9 +1,10 @@ import numpy as np import torch +from env import quantise_centre_dequantise class ExperienceReplay(): - def __init__(self, size, symbolic_env, observation_size, action_size, device): + def __init__(self, size, symbolic_env, observation_size, action_size, bit_depth, device): self.device = device self.symbolic_env = symbolic_env self.size = size @@ -14,12 +15,13 @@ def __init__(self, size, symbolic_env, observation_size, action_size, device): self.idx = 0 self.full = False # Tracks if memory has been filled/all slots are valid self.steps, self.episodes = 0, 0 # Tracks how much experience has been used in total + self.bit_depth = bit_depth def append(self, observation, action, reward, done): if self.symbolic_env: self.observations[self.idx] = observation.numpy() else: - self.observations[self.idx] = np.multiply(observation.numpy(), 255.).astype(np.uint8) # Discretise visual observations (to save memory) + self.observations[self.idx] = np.multiply(observation.numpy() + 0.5, 255.).astype(np.uint8) # Decentre and discretise visual observations (to save memory) self.actions[self.idx] = action.numpy() self.rewards[self.idx] = reward self.nonterminals[self.idx] = not done @@ -38,12 +40,12 @@ def _sample_idx(self, L): def _retrieve_batch(self, idxs, n, L): vec_idxs = idxs.transpose().reshape(-1) # Unroll indices - observations = self.observations[vec_idxs].astype(np.float32) + observations = torch.as_tensor(self.observations[vec_idxs].astype(np.float32)) if not self.symbolic_env: - observations = np.divide(observations, 255.) # Undo discretisation for visual observations + quantise_centre_dequantise(observations, self.bit_depth) # Undo discretisation for visual observations return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), self.rewards[vec_idxs].reshape(L, n), self.nonterminals[vec_idxs].reshape(L, n, 1) # Returns a batch of sequence chunks uniformly sampled from the memory def sample(self, n, L): batch = self._retrieve_batch(np.asarray([self._sample_idx(L) for _ in range(n)]), n, L) - return [torch.from_numpy(item).to(device=self.device) for item in batch] + return [torch.as_tensor(item).to(device=self.device) for item in batch]