From b8b2d30a8399e1448d1f4c1264c343875727f053 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 3 Feb 2025 10:43:56 +0100 Subject: [PATCH] Add `has_attr` for `VecEnv` (#2077) * Add `has_attr` for `VecEnv` * Add special case for gymnasium<1.0 * Update changelog.rst * Update black version --- docs/misc/changelog.rst | 41 +++++++++++++++---- setup.py | 2 +- .../common/vec_env/base_vec_env.py | 18 ++++++++ .../common/vec_env/subproc_vec_env.py | 17 +++++++- stable_baselines3/version.txt | 2 +- tests/test_vec_envs.py | 18 ++++++++ 6 files changed, 86 insertions(+), 12 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cf0db00fa..c7967a1e1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,24 +3,20 @@ Changelog ========== -Release 2.5.0 (2025-01-27) +Release 2.6.0a0 (WIP) -------------------------- -**New algorithm: SimBa in SBX, NumPy 2.0 support** - Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Increased minimum required version of PyTorch to 2.3.0 -- Removed support for Python 3.8 New Features: ^^^^^^^^^^^^^ -- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too -- Added official support for Python 3.12 +- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists Bug Fixes: ^^^^^^^^^^ +- `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt` `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -30,12 +26,39 @@ Bug Fixes: `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ -- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL -- Added support for parameter resets Deprecations: ^^^^^^^^^^^^^ +Others: +^^^^^^^ +- Updated black from v24 to v25 + +Documentation: +^^^^^^^^^^^^^^ + + +Release 2.5.0 (2025-01-27) +-------------------------- + +**New algorithm: SimBa in SBX, NumPy 2.0 support** + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Increased minimum required version of PyTorch to 2.3.0 +- Removed support for Python 3.8 + +New Features: +^^^^^^^^^^^^^ +- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too +- Added official support for Python 3.12 + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ +- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL +- Added support for parameter resets + Others: ^^^^^^^ - Updated Dockerfile diff --git a/setup.py b/setup.py index fa24fc8a3..8123cf43a 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ # Lint code and sort imports (flake8 and isort replacement) "ruff>=0.3.1", # Reformat - "black>=24.2.0,<25", + "black>=25.1.0,<26", ], "docs": [ "sphinx>=5,<9", diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 71ee15e61..370113108 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -147,6 +147,21 @@ def close(self) -> None: """ raise NotImplementedError() + def has_attr(self, attr_name: str) -> bool: + """ + Check if an attribute exists for a vectorized environment. + + :param attr_name: The name of the attribute to check + :return: True if 'attr_name' exists in all environments + """ + # Default implementation, will not work with things that cannot be pickled: + # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49 + try: + self.get_attr(attr_name) + return True + except AttributeError: + return False + @abstractmethod def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """ @@ -392,6 +407,9 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: def get_images(self) -> Sequence[Optional[np.ndarray]]: return self.venv.get_images() + def has_attr(self, attr_name: str) -> bool: + return self.venv.has_attr(attr_name) + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: return self.venv.get_attr(attr_name, indices) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 225eadd79..1563d70b1 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env.patch_gym import _patch_env -def _worker( +def _worker( # noqa: C901 remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper, @@ -58,6 +58,12 @@ def _worker( remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": remote.send(env.get_wrapper_attr(data)) + elif cmd == "has_attr": + try: + env.get_wrapper_attr(data) + remote.send(True) + except AttributeError: + remote.send(False) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -66,6 +72,8 @@ def _worker( raise NotImplementedError(f"`{cmd}` is not implemented in the worker") except EOFError: break + except KeyboardInterrupt: + break class SubprocVecEnv(VecEnv): @@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]: outputs = [pipe.recv() for pipe in self.remotes] return outputs + def has_attr(self, attr_name: str) -> bool: + """Check if an attribute exists for a vectorized environment. (see base class).""" + target_remotes = self._get_target_remotes(indices=None) + for remote in target_remotes: + remote.send(("has_attr", attr_name)) + return all([remote.recv() for remote in target_remotes]) + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """Return attribute from vectorized environment (see base class).""" target_remotes = self._get_target_remotes(indices) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 437459cd9..3d87ca93f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.5.0 +2.6.0a0 diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 7e4e5ec0b..43a693ddd 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -123,12 +123,30 @@ def make_env(): # we need a X server to test the "human" mode (uses OpenCV) # vec_env.render(mode="human") + # Set a new attribute, on the last wrapper and on the env + assert not vec_env.has_attr("dummy") + # Set value for the last wrapper only + vec_env.set_attr("dummy", 12) + assert vec_env.get_attr("dummy") == [12] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].dummy == 12 + + assert not vec_env.has_attr("dummy2") + # Set the value on the original env + # `set_wrapper_attr` doesn't exist before v1.0 + if gym.__version__ > "1": + vec_env.env_method("set_wrapper_attr", "dummy2", 2) + assert vec_env.get_attr("dummy2") == [2] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].unwrapped.dummy2 == 2 + env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set current_step to an arbitrary value for env_idx in range(N_ENVS): setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx)) # Retrieve the value for each environment + assert vec_env.has_attr("current_step") getattr_results = vec_env.get_attr("current_step") assert len(env_method_results) == N_ENVS