Skip to content

Commit

Permalink
Fix preprocessing and postprocessing of observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed May 28, 2019
1 parent 2217d6f commit ee9b996
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
14 changes: 10 additions & 4 deletions env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@
CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2}


def quantise_centre_dequantise(images, bit_depth):
images.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5) # Quantise to given bit depth and centre
images.add_(torch.rand_like(images).div_(2 ** bit_depth)) # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)
# Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5])
def preprocess_observation_(observation, bit_depth):
observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5) # Quantise to given bit depth and centre
observation.add_(torch.rand_like(observation).div_(2 ** bit_depth)) # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)


# Postprocess an observation for storage (from float32 numpy array [-0.5, 0.5] to uint8 numpy array [0, 255])
def postprocess_observation(observation, bit_depth):
return np.clip(np.floor((observation + 0.5) * 2 ** bit_depth) * 2 ** (8 - bit_depth), 0, 2 ** 8 - 1).astype(np.uint8)


def _images_to_observation(images, bit_depth):
images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32) # Resize and put channel first
quantise_centre_dequantise(images, bit_depth) # Quantise, centre and dequantise inplace
preprocess_observation_(images, bit_depth) # Quantise, centre and dequantise inplace
return images.unsqueeze(dim=0) # Add batch dimension


Expand Down
6 changes: 3 additions & 3 deletions memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import torch
from env import quantise_centre_dequantise
from env import postprocess_observation, preprocess_observation_


class ExperienceReplay():
Expand All @@ -21,7 +21,7 @@ def append(self, observation, action, reward, done):
if self.symbolic_env:
self.observations[self.idx] = observation.numpy()
else:
self.observations[self.idx] = np.multiply(observation.numpy() + 0.5, 255.).astype(np.uint8) # Decentre and discretise visual observations (to save memory)
self.observations[self.idx] = postprocess_observation(observation.numpy(), self.bit_depth) # Decentre and discretise visual observations (to save memory)
self.actions[self.idx] = action.numpy()
self.rewards[self.idx] = reward
self.nonterminals[self.idx] = not done
Expand All @@ -42,7 +42,7 @@ def _retrieve_batch(self, idxs, n, L):
vec_idxs = idxs.transpose().reshape(-1) # Unroll indices
observations = torch.as_tensor(self.observations[vec_idxs].astype(np.float32))
if not self.symbolic_env:
quantise_centre_dequantise(observations, self.bit_depth) # Undo discretisation for visual observations
preprocess_observation_(observations, self.bit_depth) # Undo discretisation for visual observations
return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), self.rewards[vec_idxs].reshape(L, n), self.nonterminals[vec_idxs].reshape(L, n, 1)

# Returns a batch of sequence chunks uniformly sampled from the memory
Expand Down
5 changes: 3 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ def lineplot(xs, ys_population, title, path='', xaxis='episode'):

if isinstance(ys_population[0], list) or isinstance(ys_population[0], tuple):
ys = np.asarray(ys_population, dtype=np.float32)
ys_min, ys_max, ys_mean, ys_std = ys.min(1), ys.max(1), ys.mean(1), ys.std(1)
ys_min, ys_max, ys_mean, ys_std, ys_median = ys.min(1), ys.max(1), ys.mean(1), ys.std(1), np.median(ys, 1)
ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std

trace_max = Scatter(x=xs, y=ys_max, line=Line(color=max_colour, dash='dash'), name='Max')
trace_upper = Scatter(x=xs, y=ys_upper, line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False)
trace_mean = Scatter(x=xs, y=ys_mean, fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean')
trace_lower = Scatter(x=xs, y=ys_lower, fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), name='-1 Std. Dev.', showlegend=False)
trace_min = Scatter(x=xs, y=ys_min, line=Line(color=max_colour, dash='dash'), name='Min')
data = [trace_upper, trace_mean, trace_lower, trace_min, trace_max]
trace_median = Scatter(x=xs, y=ys_median, line=Line(color=max_colour), name='Median')
data = [trace_upper, trace_mean, trace_lower, trace_min, trace_max, trace_median]
else:
data = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))]
plotly.offline.plot({
Expand Down

0 comments on commit ee9b996

Please sign in to comment.