From 296e7a047eb573c4bafc19ad39e527f95ff9d6b4 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Wed, 29 May 2024 10:45:58 +0100 Subject: [PATCH] Add `has_wrapper_attr` (#1070) --- gymnasium/core.py | 11 +++++++++++ gymnasium/utils/play.py | 14 +++++--------- tests/test_core.py | 7 ++++++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 54bc6c28a..ab4d44860 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -264,6 +264,10 @@ def __exit__(self, *args: Any): # propagate exception return False + def has_wrapper_attr(self, name: str) -> bool: + """Checks if the attribute `name` exists in the environment.""" + return hasattr(self, name) + def get_wrapper_attr(self, name: str) -> Any: """Gets the attribute `name` from the environment.""" return getattr(self, name) @@ -392,6 +396,13 @@ def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec: kwargs=kwargs, ) + def has_wrapper_attr(self, name: str) -> bool: + """Checks if the given attribute is within the wrapper or its environment.""" + if hasattr(self, name): + return True + else: + return self.env.has_wrapper_attr(name) + def get_wrapper_attr(self, name: str) -> Any: """Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object. diff --git a/gymnasium/utils/play.py b/gymnasium/utils/play.py index 50678f9c1..8c414e79e 100644 --- a/gymnasium/utils/play.py +++ b/gymnasium/utils/play.py @@ -70,15 +70,13 @@ def _get_relevant_keys( self, keys_to_action: dict[tuple[int], int] | None = None ) -> set: if keys_to_action is None: - if hasattr(self.env, "get_keys_to_action"): - keys_to_action = self.env.get_keys_to_action() - elif hasattr(self.env.unwrapped, "get_keys_to_action"): - keys_to_action = self.env.unwrapped.get_keys_to_action() + if self.env.has_wrapper_attr("get_keys_to_action"): + keys_to_action = self.env.get_wrapper_attr("get_keys_to_action")() else: assert self.env.spec is not None raise MissingKeysToAction( f"{self.env.spec.id} does not have explicit key to action mapping, " - "please specify one manually" + "please specify one manually, `play(env, keys_to_action=...)`" ) assert isinstance(keys_to_action, dict) relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) @@ -244,10 +242,8 @@ def play( env.reset(seed=seed) if keys_to_action is None: - if hasattr(env, "get_keys_to_action"): - keys_to_action = env.get_keys_to_action() - elif hasattr(env.unwrapped, "get_keys_to_action"): - keys_to_action = env.unwrapped.get_keys_to_action() + if env.has_wrapper_attr("get_keys_to_action"): + keys_to_action = env.get_wrapper_attr("get_keys_to_action")() else: assert env.spec is not None raise MissingKeysToAction( diff --git a/tests/test_core.py b/tests/test_core.py index 374bf4bb9..9ce032bce 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -168,22 +168,25 @@ def test_reward_observation_action_wrapper(): def test_get_set_wrapper_attr(): env = gym.make("CartPole-v1") + assert env is not env.unwrapped # Test get_wrapper_attr with pytest.raises(AttributeError): env.gravity assert env.unwrapped.gravity is not None + assert env.has_wrapper_attr("gravity") assert env.get_wrapper_attr("gravity") is not None with pytest.raises(AttributeError): env.unknown_attr + assert env.has_wrapper_attr("unknown_attr") is False with pytest.raises(AttributeError): env.get_wrapper_attr("unknown_attr") # Test set_wrapper_attr env.set_wrapper_attr("gravity", 10.0) with pytest.raises(AttributeError): - env.gravity + env.gravity # checks the top level wrapper hasn't been updated assert env.unwrapped.gravity == 10.0 assert env.get_wrapper_attr("gravity") == 10.0 @@ -195,10 +198,12 @@ def test_get_set_wrapper_attr(): # Test with OrderEnforcing (intermediate wrapper) assert not isinstance(env, OrderEnforcing) + # show that the base and top level objects don't contain the attribute with pytest.raises(AttributeError): env._disable_render_order_enforcing with pytest.raises(AttributeError): env.unwrapped._disable_render_order_enforcing + assert env.has_wrapper_attr("_disable_render_order_enforcing") assert env.get_wrapper_attr("_disable_render_order_enforcing") is False env.set_wrapper_attr("_disable_render_order_enforcing", True)