From a5b1a60ca975309c1dda3b65e0d804ce39545d19 Mon Sep 17 00:00:00 2001
From: Kaixhin
Date: Tue, 21 May 2019 15:40:11 +0100
Subject: [PATCH] Quantise images to 5-bit depth

---
 env.py  | 27 ++++++++++++++++++---------
 main.py |  7 ++++---
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/env.py b/env.py
index 0b38e8d..6a16f7d 100644
--- a/env.py
+++ b/env.py
@@ -8,8 +8,15 @@
 CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2}
 
 
+def _images_to_observation(images, bit_depth):
+  images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32)  # Resize
+  images.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth)  # Quantise to given bit depth (note that original implementation also centers data)
+  images.add_(torch.rand_like(images).div_(2 ** bit_depth))  # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)
+  return images.unsqueeze(dim=0)  # Add batch dimension
+
+
 class ControlSuiteEnv():
-  def __init__(self, env, symbolic, seed, max_episode_length, action_repeat):
+  def __init__(self, env, symbolic, seed, max_episode_length, action_repeat, bit_depth):
     from dm_control import suite
     from dm_control.suite.wrappers import pixels
     domain, task = env.split('-')
@@ -21,6 +28,7 @@ def __init__(self, env, symbolic, seed, max_episode_length, action_repeat):
     self.action_repeat = action_repeat
     if action_repeat != CONTROL_SUITE_ACTION_REPEATS[domain]:
       print('Using action repeat %d; recommended action repeat for domain is %d' % (action_repeat, CONTROL_SUITE_ACTION_REPEATS[domain]))
+    self.bit_depth = bit_depth
 
   def reset(self):
     self.t = 0  # Reset internal timer
@@ -28,7 +36,7 @@ def reset(self):
     if self.symbolic:
       return torch.tensor(np.concatenate([np.array([obs]) if isinstance(obs, float) else obs for obs in state.observation.values()], axis=0), dtype=torch.float32).unsqueeze(dim=0)
     else:
-      return torch.tensor(cv2.resize(self._env.physics.render(camera_id=0), (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32).div_(255).unsqueeze(dim=0)
+      return _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth)
 
   def step(self, action):
     action = action.detach().numpy()
@@ -43,7 +51,7 @@ def step(self, action):
     if self.symbolic:
       observation = torch.tensor(np.concatenate([np.array([obs]) if isinstance(obs, float) else obs for obs in state.observation.values()], axis=0), dtype=torch.float32).unsqueeze(dim=0)
     else:
-      observation = torch.tensor(cv2.resize(self._env.physics.render(camera_id=0), (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32).div_(255).unsqueeze(dim=0)
+      observation = _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth)
     return observation, reward, done
 
   def render(self):
@@ -70,13 +78,14 @@ def sample_random_action(self):
 
 
 class GymEnv():
-  def __init__(self, env, symbolic, seed, max_episode_length, action_repeat):
+  def __init__(self, env, symbolic, seed, max_episode_length, action_repeat, bit_depth):
     import gym
     self.symbolic = symbolic
     self._env = gym.make(env)
     self._env.seed(seed)
     self.max_episode_length = max_episode_length
     self.action_repeat = action_repeat
+    self.bit_depth = bit_depth
 
   def reset(self):
     self.t = 0  # Reset internal timer
@@ -84,7 +93,7 @@ def reset(self):
     if self.symbolic:
       return torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0)
     else:
-      return torch.tensor(cv2.resize(self._env.render(mode='rgb_array'), (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32).div_(255).unsqueeze(dim=0)
+      return _images_to_observation(self._env.render(mode='rgb_array'), self.bit_depth)
 
   def step(self, action):
     action = action.detach().numpy()
@@ -99,7 +108,7 @@ def step(self, action):
     if self.symbolic:
       observation = torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0)
     else:
-      observation = torch.tensor(cv2.resize(self._env.render(mode='rgb_array'), (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32).div_(255).unsqueeze(dim=0)
+      observation = _images_to_observation(self._env.render(mode='rgb_array'), self.bit_depth)
     return observation, reward, done
 
   def render(self):
@@ -121,11 +130,11 @@ def sample_random_action(self):
     return torch.from_numpy(self._env.action_space.sample())
 
 
-def Env(env, symbolic, seed, max_episode_length, action_repeat):
+def Env(env, symbolic, seed, max_episode_length, action_repeat, bit_depth):
   if env in GYM_ENVS:
-    return GymEnv(env, symbolic, seed, max_episode_length, action_repeat)
+    return GymEnv(env, symbolic, seed, max_episode_length, action_repeat, bit_depth)
   elif env in CONTROL_SUITE_ENVS:
-    return ControlSuiteEnv(env, symbolic, seed, max_episode_length, action_repeat)
+    return ControlSuiteEnv(env, symbolic, seed, max_episode_length, action_repeat, bit_depth)
 
 
 # Wrapper for batching environments together
diff --git a/main.py b/main.py
index b694279..fc86133 100644
--- a/main.py
+++ b/main.py
@@ -40,6 +40,7 @@
 parser.add_argument('--overshooting-reward-scale', type=float, default=0, metavar='R>1', help='Latent overshooting reward prediction weight for t > 1 (0 to disable)')
 parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg', help='Global KL weight (0 to disable)')
 parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats')
+parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)')
 parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate')  # Note that original has a linear learning rate decay, but it seems unlikely that this makes a significant difference
 parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm')
 parser.add_argument('--planning-horizon', type=int, default=12, metavar='H', help='Planning horizon distance')
@@ -48,7 +49,7 @@
 parser.add_argument('--top-candidates', type=int, default=100, metavar='K', help='Number of top candidates to fit')
 parser.add_argument('--test-interval', type=int, default=25, metavar='I', help='Test interval (episodes)')
 parser.add_argument('--test-episodes', type=int, default=10, metavar='E', help='Number of test episodes')
-parser.add_argument('--checkpoint-interval', type=int, default=25, metavar='I', help='Checkpoint interval (episodes)')
+parser.add_argument('--checkpoint-interval', type=int, default=50, metavar='I', help='Checkpoint interval (episodes)')
 parser.add_argument('--checkpoint-experience', action='store_true', help='Checkpoint experience replay')
 parser.add_argument('--load-experience', action='store_true', help='Load experience replay (from checkpoint dir)')
 parser.add_argument('--load-checkpoint', type=int, default=0, metavar='E', help='Load model checkpoint (from given episode)')
@@ -74,7 +75,7 @@
 
 
 # Initialise training environment and experience replay memory
-env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat)
+env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth)
 if args.load_experience:
   D = torch.load(os.path.join('checkpoints', 'experience.pth'))
   metrics['steps'], metrics['episodes'] = [D.steps] * D.episodes, list(range(1, D.episodes + 1))
@@ -207,7 +208,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
     reward_model.eval()
     encoder.eval()
     # Initialise parallelised test environments
-    test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat), {}, args.test_episodes)
+    test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth), {}, args.test_episodes)
 
     with torch.no_grad():
       observation, total_rewards, video_frames = test_envs.reset(), np.zeros((args.test_episodes, )), []
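For reference, a minimal standalone sketch (not part of the patch) of what the new preprocessing does to a single 8-bit frame; `dummy_frame` is a hypothetical stand-in for an environment render:

import numpy as np
import torch

bit_depth = 5
dummy_frame = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # fake 8-bit RGB render

x = torch.tensor(dummy_frame.transpose(2, 0, 1), dtype=torch.float32)  # to CHW float, as in _images_to_observation
x.div_(2 ** (8 - bit_depth)).floor_()            # drop the 3 least significant bits: values now in {0, ..., 31}
x.div_(2 ** bit_depth)                           # rescale to [0, 1) in steps of 1/32
x.add_(torch.rand_like(x).div_(2 ** bit_depth))  # dequantise: spread each discrete value uniformly over its bin
print(x.min().item(), x.max().item())            # values lie in [0, 1)

As the in-diff comments note, the uniform noise makes the discrete pixel values behave like samples from a continuous density rather than a point mass per bin, and the quantisation to 5 bits matches the original implementation (which additionally centers the data).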