Commit 992ba6a: Centre data
Kaixhin committed May 22, 2019 (1 parent: 76c17ce)
Showing 3 changed files with 16 additions and 10 deletions.
env.py (7 additions, 3 deletions)
@@ -8,10 +8,14 @@
 CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2}


-def _images_to_observation(images, bit_depth):
-  images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32)  # Resize
-  images.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth)  # Quantise to given bit depth (note that original implementation also centers data)
+def quantise_centre_dequantise(images, bit_depth):
+  images.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5)  # Quantise to given bit depth and centre
   images.add_(torch.rand_like(images).div_(2 ** bit_depth))  # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)
+
+
+def _images_to_observation(images, bit_depth):
+  images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32)  # Resize and put channel first
+  quantise_centre_dequantise(images, bit_depth)  # Quantise, centre and dequantise inplace
   return images.unsqueeze(dim=0)  # Add batch dimension
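As a sanity check on the new preprocessing, a minimal sketch (not part of the commit, assuming env.py above is importable):

import torch
from env import quantise_centre_dequantise  # Defined in env.py above

frame = torch.rand(3, 64, 64).mul_(255)  # Stand-in for a raw frame with 8-bit pixel values
quantise_centre_dequantise(frame, 5)  # Operates in place at 5-bit depth
print(frame.min().item(), frame.max().item())  # Values now lie in roughly [-0.5, 0.5] rather than [0, 1]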


main.py (2 additions, 2 deletions)
@@ -81,7 +81,7 @@
   D = torch.load(os.path.join(results_dir, 'experience.pth'))
   metrics['steps'], metrics['episodes'] = [D.steps] * D.episodes, list(range(1, D.episodes + 1))
 else:
-  D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.device)
+  D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device)
 # Initialise dataset D with S random seed episodes
 for s in range(1, args.seed_episodes + 1):
   observation, done, t = env.reset(), False, 0
@@ -219,7 +219,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
     belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False)
     total_rewards += reward.numpy()
     if not args.symbolic_env:  # Collect real vs. predicted frames for video
-      video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3), nrow=5).numpy())
+      video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3), nrow=5).numpy() + 0.5)  # Decentre
     observation = next_observation
     if done.sum().item() == args.test_episodes:
       pbar.close()
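Observations are now centred in [-0.5, 0.5], so frames must be shifted back before being written out as video; a minimal sketch of the decentring (illustrative values, not part of the commit):

import torch
from torchvision.utils import make_grid

frames = torch.rand(10, 3, 64, 128) - 0.5  # Stand-in for real and predicted observations concatenated along width
grid = make_grid(frames, nrow=5).numpy() + 0.5  # Decentre back into [0, 1], as in the change above
assert 0.0 <= grid.min() and grid.max() <= 1.0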
memory.py (7 additions, 5 deletions)
@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from env import quantise_centre_dequantise


 class ExperienceReplay():
-  def __init__(self, size, symbolic_env, observation_size, action_size, device):
+  def __init__(self, size, symbolic_env, observation_size, action_size, bit_depth, device):
     self.device = device
     self.symbolic_env = symbolic_env
     self.size = size
@@ -14,12 +15,13 @@ def __init__(self, size, symbolic_env, observation_size, action_size, device):
     self.idx = 0
     self.full = False  # Tracks if memory has been filled/all slots are valid
     self.steps, self.episodes = 0, 0  # Tracks how much experience has been used in total
+    self.bit_depth = bit_depth

   def append(self, observation, action, reward, done):
     if self.symbolic_env:
       self.observations[self.idx] = observation.numpy()
     else:
-      self.observations[self.idx] = np.multiply(observation.numpy(), 255.).astype(np.uint8)  # Discretise visual observations (to save memory)
+      self.observations[self.idx] = np.multiply(observation.numpy() + 0.5, 255.).astype(np.uint8)  # Decentre and discretise visual observations (to save memory)
     self.actions[self.idx] = action.numpy()
     self.rewards[self.idx] = reward
     self.nonterminals[self.idx] = not done
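Because observations are now stored decentred as uint8 and re-centred on retrieval, the roundtrip through the buffer only loses quantisation-level precision; a minimal sketch of that roundtrip (illustrative, not part of the commit):

import numpy as np
import torch
from env import quantise_centre_dequantise

obs = torch.rand(3, 64, 64) - 0.5  # Centred observation, as produced by _images_to_observation
stored = np.multiply(obs.numpy() + 0.5, 255.).astype(np.uint8)  # Decentre and discretise, as in append()
recovered = torch.as_tensor(stored.astype(np.float32))
quantise_centre_dequantise(recovered, 5)  # Re-centre on retrieval, as in _retrieve_batch() below
print((obs - recovered).abs().max().item())  # At most a couple of quantisation steps (about 0.06 at 5 bits)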
@@ -38,12 +40,12 @@ def _sample_idx(self, L):

   def _retrieve_batch(self, idxs, n, L):
     vec_idxs = idxs.transpose().reshape(-1)  # Unroll indices
-    observations = self.observations[vec_idxs].astype(np.float32)
+    observations = torch.as_tensor(self.observations[vec_idxs].astype(np.float32))
     if not self.symbolic_env:
-      observations = np.divide(observations, 255.)  # Undo discretisation for visual observations
+      quantise_centre_dequantise(observations, self.bit_depth)  # Undo discretisation for visual observations
     return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), self.rewards[vec_idxs].reshape(L, n), self.nonterminals[vec_idxs].reshape(L, n, 1)

   # Returns a batch of sequence chunks uniformly sampled from the memory
   def sample(self, n, L):
     batch = self._retrieve_batch(np.asarray([self._sample_idx(L) for _ in range(n)]), n, L)
-    return [torch.from_numpy(item).to(device=self.device) for item in batch]
+    return [torch.as_tensor(item).to(device=self.device) for item in batch]
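With the new signature, callers pass the bit depth through to the replay memory, which re-applies the centring when batches are drawn; a minimal usage sketch (argument values are illustrative, and passing None for observation_size assumes it is only consulted for symbolic envs):

import torch
from memory import ExperienceReplay

D = ExperienceReplay(1000, False, None, 2, 5, torch.device('cpu'))  # size, symbolic_env, observation_size, action_size, bit_depth, device
for _ in range(10):  # A few dummy visual transitions
  D.append(torch.rand(3, 64, 64) - 0.5, torch.rand(2), 0.0, False)
observations, actions, rewards, nonterminals = D.sample(2, 5)  # n=2 chunks of length L=5
print(observations.shape)  # Expected: torch.Size([5, 2, 3, 64, 64]), values centred in [-0.5, 0.5]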
