From 6014cddcf605c3ffceb1a562d2231592541d7605 Mon Sep 17 00:00:00 2001
From: Rodrigo de Salvo Braz
Date: Thu, 8 Feb 2024 11:54:17 -0800
Subject: [PATCH] Fix Frozen Lake integration test

Summary: This diff fixes a Frozen Lake integration test. In the process, we
fix `OneHotObservationsFromDiscrete` and eliminate an unused and unhelpful
environment adapter, `BoxObservationsFromDiscrete`.

Reviewed By: yiwan-rl

Differential Revision: D53559195

fbshipit-source-id: 88e983b05993f76eeb1f84d2a75bd99597fb13d7
---
 .../environments/environments.py      | 51 +++---------
 test/integration/test_integration.py  | 81 +++++++++++--------
 2 files changed, 57 insertions(+), 75 deletions(-)

diff --git a/pearl/utils/instantiations/environments/environments.py b/pearl/utils/instantiations/environments/environments.py
index 9d2d3c9b..a75ad7e4 100644
--- a/pearl/utils/instantiations/environments/environments.py
+++ b/pearl/utils/instantiations/environments/environments.py
@@ -17,9 +17,8 @@
 try:
     import gymnasium as gym
 except ModuleNotFoundError:
-    import gym
+    import gym  # noqa

-import numpy as np
 import torch
 import torch.nn.functional as F
 from pearl.api.action import Action
@@ -60,11 +59,9 @@ def __str__(self) -> str:
         return type(self).__name__


-class BoxObservationsEnvironmentBase(Environment, ABC):
+class ObservationTransformationEnvironmentAdapterBase(Environment, ABC):
     """
-    An environment adapter mapping a Discrete observation space into
-    a Box observation space with dimension 1.
-    This is useful to use with agents expecting tensor observations.
+    A base for environment adapters transforming observations.
     """

     def __init__(
@@ -106,33 +103,7 @@ def short_description(self) -> str:
         return self.__class__.__name__


-class BoxObservationsFromDiscrete(BoxObservationsEnvironmentBase):
-    """
-    An environment adapter mapping a Discrete observation space into
-    a Box observation space with dimension 1.
-    The observations are tensors of length 1 containing the original observations.
-
-    This is useful to use with agents expecting tensor observations.
-    """
-
-    def __init__(self, base_environment: Environment) -> None:
-        super(BoxObservationsFromDiscrete, self).__init__(base_environment)
-
-    @staticmethod
-    def make_observation_space(base_environment: Environment) -> Space:
-        low_action = np.array([0])
-        # pyre-fixme: need to add this property in Environment
-        # and implement it in all concrete subclasses
-        assert isinstance(base_environment.observation_space, DiscreteSpace)
-        high_action = np.array([base_environment.observation_space.n - 1])
-        # pyre-fixme: returning Gym Box but needs to return Pearl Space
-        return gym.spaces.Box(low=low_action, high=high_action, shape=(1,))
-
-    def compute_tensor_observation(self, observation: Observation) -> torch.Tensor:
-        return torch.tensor([observation])
-
-
-class OneHotObservationsFromDiscrete(BoxObservationsEnvironmentBase):
+class OneHotObservationsFromDiscrete(ObservationTransformationEnvironmentAdapterBase):
     """
     An environment adapter mapping a Discrete observation space into
     a Box observation space with dimension 1
@@ -146,28 +117,26 @@ def __init__(self, base_environment: Environment) -> None:

     @staticmethod
     def make_observation_space(base_environment: Environment) -> Space:
-        # pyre-fixme: need to add this property in Environment
+        # pyre-fixme: need to add `observation_space` property in Environment
         # and implement it in all concrete subclasses
         assert isinstance(base_environment.observation_space, DiscreteSpace)
         n = base_environment.observation_space.n
-        low = np.full((n,), 0)
-        high = np.full((n,), 1)
-        # pyre-fixme: returning Gym Box but needs to return Pearl Space
-        return gym.spaces.Box(low=low, high=high, shape=(n,))
+        elements = [F.one_hot(torch.tensor(i), n).float() for i in range(n)]
+        return DiscreteSpace(elements)

     def compute_tensor_observation(self, observation: Observation) -> torch.Tensor:
         if isinstance(observation, torch.Tensor):
             observation_tensor = observation
         else:
             observation_tensor = torch.tensor(observation)
-        # pyre-fixme: need to add this property in Environment
+        # pyre-fixme: need to add `observation_space` property in Environment
         # and implement it in all concrete subclasses
         assert isinstance(self.base_environment.observation_space, DiscreteSpace)
         return F.one_hot(
             observation_tensor,
             self.base_environment.observation_space.n,
-        )
+        ).float()

     @property
     def short_description(self) -> str:
-        return "One-hot observations"
+        return f"One-hot observations on {self.base_environment}"
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 8a68030e..1df416f1 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -5,8 +5,17 @@
 # LICENSE file in the root directory of this source tree.
# +from pearl.utils.instantiations.spaces.discrete import DiscreteSpace + +try: + import gymnasium as gym +except ModuleNotFoundError: + import gym # noqa + import unittest +from gym.envs.toy_text.frozen_lake import generate_random_map + from pearl.action_representation_modules.one_hot_action_representation_module import ( OneHotActionTensorRepresentationModule, ) @@ -72,6 +81,10 @@ target_return_is_reached, ) +from pearl.utils.instantiations.environments.environments import ( + OneHotObservationsFromDiscrete, +) + from pearl.utils.instantiations.environments.gym_environment import GymEnvironment from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace @@ -113,40 +126,40 @@ def test_dqn(self) -> None: ) ) - # def test_dqn_on_frozen_lake(self) -> None: - # """ - # This test is checking if DQN will eventually solve FrozenLake-v1 - # whose observations need to be wrapped in a one-hot representation. - # """ - # # TODO: flaky: sometimes not even 5,000 episodes is enough for learning - # # Need to debug. - # - # environment = OneHotObservationsFromDiscrete( - # GymEnvironment("FrozenLake-v1", is_slippery=False) - # ) - # state_dim = environment.observation_space.shape[0] - # agent = PearlAgent( - # policy_learner=DeepQLearning( - # state_dim=state_dim, - # action_space=environment.action_space, - # hidden_dims=[state_dim // 2, state_dim // 2], - # training_rounds=40, - # ), - # replay_buffer=FIFOOffPolicyReplayBuffer(1000), - # ) - - # self.assertTrue( - # target_return_is_reached( - # target_return=1.0, - # required_target_returns_in_a_row=5, - # max_episodes=1000, - # agent=agent, - # env=environment, - # learn=True, - # learn_after_episode=True, - # exploit=False, - # ) - # ) + def test_dqn_on_frozen_lake(self) -> None: + """ + This test is checking if DQN will eventually solve FrozenLake-v1 + whose observations need to be wrapped in a one-hot representation. + """ + environment = OneHotObservationsFromDiscrete( + GymEnvironment( + "FrozenLake-v1", is_slippery=False, desc=generate_random_map(size=4) + ) + ) + assert isinstance(environment.action_space, DiscreteSpace) + state_dim = environment.observation_space.n + agent = PearlAgent( + policy_learner=DeepQLearning( + state_dim=state_dim, + action_space=environment.action_space, + hidden_dims=[64], + training_rounds=20, + ), + replay_buffer=FIFOOffPolicyReplayBuffer(1000), + ) + + self.assertTrue( + target_return_is_reached( + target_return=1.0, + required_target_returns_in_a_row=5, + max_episodes=1000, + agent=agent, + env=environment, + learn=True, + learn_after_episode=True, + exploit=False, + ) + ) def test_double_dqn(self) -> None: """
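
Usage sketch (not part of the patch): a minimal example of how the fixed adapter is exercised by the re-enabled test. It assumes only the Pearl classes touched in the hunks above (`GymEnvironment`, `OneHotObservationsFromDiscrete`, `DiscreteSpace`) and gym's `generate_random_map`:

from gym.envs.toy_text.frozen_lake import generate_random_map

from pearl.utils.instantiations.environments.environments import (
    OneHotObservationsFromDiscrete,
)
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

# Wrap FrozenLake's integer observations so the agent receives one-hot tensors.
environment = OneHotObservationsFromDiscrete(
    GymEnvironment(
        "FrozenLake-v1", is_slippery=False, desc=generate_random_map(size=4)
    )
)

# After this fix, the wrapped observation space is a DiscreteSpace whose
# elements are the one-hot vectors themselves, so its size doubles as the
# state dimension passed to DeepQLearning in the test above.
assert isinstance(environment.observation_space, DiscreteSpace)
state_dim = environment.observation_space.n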