From 60cd10446bef9d80f9f729c5dc0b0f9d89562d08 Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Wed, 4 Sep 2024 01:18:22 -0700 Subject: [PATCH] [BugFix] Fix support for MiniGrid envs (#2416) --- .../linux_libs/scripts_gym/environment.yml | 1 + test/test_libs.py | 28 ++++++++++++++++ torchrl/envs/gym_like.py | 13 +++++--- torchrl/envs/libs/gym.py | 33 ++++++++++++++----- 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_gym/environment.yml b/.github/unittest/linux_libs/scripts_gym/environment.yml index d30aa6d0f91..4c0c9269479 100644 --- a/.github/unittest/linux_libs/scripts_gym/environment.yml +++ b/.github/unittest/linux_libs/scripts_gym/environment.yml @@ -7,6 +7,7 @@ dependencies: - pip: # Initial version is required to install Atari ROMS in setup_env.sh - gym[atari]==0.13 + - minigrid - hypothesis - future - cloudpickle diff --git a/test/test_libs.py b/test/test_libs.py index cb551473690..6f5cc1bebeb 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -153,6 +153,16 @@ _has_meltingpot = importlib.util.find_spec("meltingpot") is not None +_has_minigrid = importlib.util.find_spec("minigrid") is not None + + +@pytest.fixture(scope="session", autouse=True) +def maybe_init_minigrid(): + if _has_minigrid and _has_gymnasium: + import minigrid + + minigrid.register_minigrid_envs() + def get_gym_pixel_wrapper(): try: @@ -1279,6 +1289,24 @@ def test_resetting_strategies(self, heterogeneous): gc.collect() +@pytest.mark.skipif( + not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found" +) +class TestMiniGrid: + @pytest.mark.parametrize( + "id", + [ + "BabyAI-KeyCorridorS6R3-v0", + "MiniGrid-Empty-16x16-v0", + "MiniGrid-BlockedUnlockPickup-v0", + ], + ) + def test_minigrid(self, id): + env_base = gymnasium.make(id) + env = GymWrapper(env_base) + check_env_specs(env) + + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 gym = gym_backend() diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 9092d419075..995f245a8ac 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -12,10 +12,10 @@ import numpy as np import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import NonTensorData, TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded +from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper, EnvBase @@ -283,9 +283,12 @@ def read_obs( observations = observations_dict else: for key, val in observations.items(): - observations[key] = self.observation_spec[key].encode( - val, ignore_device=True - ) + if isinstance(self.observation_spec[key], NonTensor): + observations[key] = NonTensorData(val) + else: + observations[key] = self.observation_spec[key].encode( + val, ignore_device=True + ) return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 34af87b75f9..a82286659cb 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -29,6 +29,7 @@ Composite, MultiCategorical, MultiOneHot, + NonTensor, OneHot, TensorSpec, Unbounded, @@ -55,6 +56,14 @@ _has_mo = importlib.util.find_spec("mo_gymnasium") is not None _has_sb3 = importlib.util.find_spec("stable_baselines3") is not None +_has_minigrid = importlib.util.find_spec("minigrid") is not None + + +def _minigrid_lib(): + assert _has_minigrid, "minigrid not found" + import minigrid + + return minigrid class set_gym_backend(_DecoratorContextManager): @@ -369,6 +378,8 @@ def _gym_to_torchrl_spec_transform( categorical_action_encoding=categorical_action_encoding, remap_state_to_observation=remap_state_to_observation, ) + elif _has_minigrid and isinstance(spec, _minigrid_lib().core.mission.MissionSpace): + return NonTensor((), device=device) else: raise NotImplementedError( f"spec of type {type(spec).__name__} is currently unaccounted for" @@ -766,14 +777,20 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): self._seed_calls_reset = None self._categorical_action_encoding = categorical_action_encoding if env is not None: - if "EnvCompatibility" in str( - env - ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env - raise ValueError( - "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " - "If this feature is needed, detail your use case in an issue of " - "https://github.com/pytorch/rl/issues." - ) + try: + env_str = str(env) + except TypeError: + # MiniGrid has a bug where the __str__ method fails + pass + else: + if ( + "EnvCompatibility" in env_str + ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env + raise ValueError( + "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " + "If this feature is needed, detail your use case in an issue of " + "https://github.com/pytorch/rl/issues." + ) libname = self.get_library_name(env) with set_gym_backend(libname): kwargs["env"] = env