Fix Frozen Lake integration test
Summary:
This diff fixes a Frozen Lake integration test.
In the process, we fix `OneHotObservationsFromDiscrete` and remove `BoxObservationsFromDiscrete`, an environment adapter that was unused and not particularly useful.

Reviewed By: yiwan-rl

Differential Revision: D53559195

fbshipit-source-id: 88e983b05993f76eeb1f84d2a75bd99597fb13d7
rodrigodesalvobraz authored and facebook-github-bot committed Feb 8, 2024
1 parent 07c8e0c commit 6014cdd
Showing 2 changed files with 57 additions and 75 deletions.
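
The substantive change to `OneHotObservationsFromDiscrete` below is largely about dtypes: `torch.nn.functional.one_hot` returns an int64 tensor, while the networks consuming these observations typically expect float inputs, which is why the diff adds `.float()` in two places. A minimal standalone sketch of the distinction (plain PyTorch, not Pearl code):

import torch
import torch.nn.functional as F

n = 16           # e.g. the 16 cells of a 4x4 FrozenLake grid
observation = 5  # a raw discrete observation (a cell index)

as_int = F.one_hot(torch.tensor(observation), n)  # dtype is torch.int64
as_float = as_int.float()                         # dtype is torch.float32, usable as network input

assert as_int.dtype == torch.int64
assert as_float.dtype == torch.float32
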
51 changes: 10 additions & 41 deletions pearl/utils/instantiations/environments/environments.py
@@ -17,9 +17,8 @@
 try:
     import gymnasium as gym
 except ModuleNotFoundError:
-    import gym
+    import gym  # noqa
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 from pearl.api.action import Action
@@ -60,11 +59,9 @@ def __str__(self) -> str:
         return type(self).__name__
 
 
-class BoxObservationsEnvironmentBase(Environment, ABC):
+class ObservationTransformationEnvironmentAdapterBase(Environment, ABC):
     """
-    An environment adapter mapping a Discrete observation space into
-    a Box observation space with dimension 1.
-    This is useful to use with agents expecting tensor observations.
+    A base for environment adapters transforming observations.
     """
 
     def __init__(
@@ -106,33 +103,7 @@ def short_description(self) -> str:
         return self.__class__.__name__
 
 
-class BoxObservationsFromDiscrete(BoxObservationsEnvironmentBase):
-    """
-    An environment adapter mapping a Discrete observation space into
-    a Box observation space with dimension 1.
-    The observations are tensors of length 1 containing the original observations.
-    This is useful to use with agents expecting tensor observations.
-    """
-
-    def __init__(self, base_environment: Environment) -> None:
-        super(BoxObservationsFromDiscrete, self).__init__(base_environment)
-
-    @staticmethod
-    def make_observation_space(base_environment: Environment) -> Space:
-        low_action = np.array([0])
-        # pyre-fixme: need to add this property in Environment
-        # and implement it in all concrete subclasses
-        assert isinstance(base_environment.observation_space, DiscreteSpace)
-        high_action = np.array([base_environment.observation_space.n - 1])
-        # pyre-fixme: returning Gym Box but needs to return Pearl Space
-        return gym.spaces.Box(low=low_action, high=high_action, shape=(1,))
-
-    def compute_tensor_observation(self, observation: Observation) -> torch.Tensor:
-        return torch.tensor([observation])
-
-
-class OneHotObservationsFromDiscrete(BoxObservationsEnvironmentBase):
+class OneHotObservationsFromDiscrete(ObservationTransformationEnvironmentAdapterBase):
     """
     An environment adapter mapping a Discrete observation space into
     a Box observation space with dimension 1
@@ -146,28 +117,26 @@ def __init__(self, base_environment: Environment) -> None:
 
     @staticmethod
     def make_observation_space(base_environment: Environment) -> Space:
-        # pyre-fixme: need to add this property in Environment
+        # pyre-fixme: need to add `observation_space` property in Environment
         # and implement it in all concrete subclasses
        assert isinstance(base_environment.observation_space, DiscreteSpace)
         n = base_environment.observation_space.n
-        low = np.full((n,), 0)
-        high = np.full((n,), 1)
-        # pyre-fixme: returning Gym Box but needs to return Pearl Space
-        return gym.spaces.Box(low=low, high=high, shape=(n,))
+        elements = [F.one_hot(torch.tensor(i), n).float() for i in range(n)]
+        return DiscreteSpace(elements)
 
     def compute_tensor_observation(self, observation: Observation) -> torch.Tensor:
         if isinstance(observation, torch.Tensor):
             observation_tensor = observation
         else:
             observation_tensor = torch.tensor(observation)
-        # pyre-fixme: need to add this property in Environment
+        # pyre-fixme: need to add `observation_space` property in Environment
         # and implement it in all concrete subclasses
         assert isinstance(self.base_environment.observation_space, DiscreteSpace)
         return F.one_hot(
             observation_tensor,
             self.base_environment.observation_space.n,
-        )
+        ).float()
 
     @property
     def short_description(self) -> str:
-        return "One-hot observations"
+        return f"One-hot observations on {self.base_environment}"
81 changes: 47 additions & 34 deletions test/integration/test_integration.py
@@ -5,8 +5,17 @@
 # LICENSE file in the root directory of this source tree.
 #
 
+from pearl.utils.instantiations.spaces.discrete import DiscreteSpace
+
+try:
+    import gymnasium as gym
+except ModuleNotFoundError:
+    import gym  # noqa
+
 import unittest
 
+from gym.envs.toy_text.frozen_lake import generate_random_map
+
 from pearl.action_representation_modules.one_hot_action_representation_module import (
     OneHotActionTensorRepresentationModule,
 )
@@ -72,6 +81,10 @@
     target_return_is_reached,
 )
 
+from pearl.utils.instantiations.environments.environments import (
+    OneHotObservationsFromDiscrete,
+)
+
 from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
 from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
 
@@ -113,40 +126,40 @@ def test_dqn(self) -> None:
             )
         )
 
-    # def test_dqn_on_frozen_lake(self) -> None:
-    #     """
-    #     This test is checking if DQN will eventually solve FrozenLake-v1
-    #     whose observations need to be wrapped in a one-hot representation.
-    #     """
-    #     # TODO: flaky: sometimes not even 5,000 episodes is enough for learning
-    #     # Need to debug.
-    #
-    #     environment = OneHotObservationsFromDiscrete(
-    #         GymEnvironment("FrozenLake-v1", is_slippery=False)
-    #     )
-    #     state_dim = environment.observation_space.shape[0]
-    #     agent = PearlAgent(
-    #         policy_learner=DeepQLearning(
-    #             state_dim=state_dim,
-    #             action_space=environment.action_space,
-    #             hidden_dims=[state_dim // 2, state_dim // 2],
-    #             training_rounds=40,
-    #         ),
-    #         replay_buffer=FIFOOffPolicyReplayBuffer(1000),
-    #     )
-
-    #     self.assertTrue(
-    #         target_return_is_reached(
-    #             target_return=1.0,
-    #             required_target_returns_in_a_row=5,
-    #             max_episodes=1000,
-    #             agent=agent,
-    #             env=environment,
-    #             learn=True,
-    #             learn_after_episode=True,
-    #             exploit=False,
-    #         )
-    #     )
+    def test_dqn_on_frozen_lake(self) -> None:
+        """
+        This test is checking if DQN will eventually solve FrozenLake-v1
+        whose observations need to be wrapped in a one-hot representation.
+        """
+        environment = OneHotObservationsFromDiscrete(
+            GymEnvironment(
+                "FrozenLake-v1", is_slippery=False, desc=generate_random_map(size=4)
+            )
+        )
+        assert isinstance(environment.action_space, DiscreteSpace)
+        state_dim = environment.observation_space.n
+        agent = PearlAgent(
+            policy_learner=DeepQLearning(
+                state_dim=state_dim,
+                action_space=environment.action_space,
+                hidden_dims=[64],
+                training_rounds=20,
+            ),
+            replay_buffer=FIFOOffPolicyReplayBuffer(1000),
+        )
+
+        self.assertTrue(
+            target_return_is_reached(
+                target_return=1.0,
+                required_target_returns_in_a_row=5,
+                max_episodes=1000,
+                agent=agent,
+                env=environment,
+                learn=True,
+                learn_after_episode=True,
+                exploit=False,
+            )
+        )
 
     def test_double_dqn(self) -> None:
         """
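
For context on the new test setup: `generate_random_map` comes from Gym's FrozenLake toy-text module (imported at the top of the test file) and returns a random square map as a list of row strings, where S is the start, F is frozen lake, H is a hole, and G is the goal, generated so that a path from start to goal exists. A standalone sketch of the call used in the test:

from gym.envs.toy_text.frozen_lake import generate_random_map

desc = generate_random_map(size=4)
print(desc)  # e.g. ['SFFF', 'FHFH', 'FFFH', 'HFFG']
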
