Skip to content

Commit

Permalink
Add has_attr for VecEnv
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Jan 30, 2025
1 parent ee8a77d commit 6f55734
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 11 deletions.
39 changes: 30 additions & 9 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,16 @@
Changelog
==========

Release 2.5.0 (2025-01-27)
Release 2.6.0a0 (WIP)
--------------------------

**New algorithm: SimBa in SBX, NumPy 2.0 support**


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8

New Features:
^^^^^^^^^^^^^
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
- Added official support for Python 3.12
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists

Bug Fixes:
^^^^^^^^^^
Expand All @@ -30,12 +25,38 @@ Bug Fixes:

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
- Added support for parameter resets

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 2.5.0 (2025-01-27)
--------------------------

**New algorithm: SimBa in SBX, NumPy 2.0 support**


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8

New Features:
^^^^^^^^^^^^^
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
- Added official support for Python 3.12

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
- Added support for parameter resets

Others:
^^^^^^^
- Updated Dockerfile
Expand Down
18 changes: 18 additions & 0 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ def close(self) -> None:
"""
raise NotImplementedError()

def has_attr(self, attr_name: str) -> bool:
"""
Check if an attribute exists for a vectorized environment.
:param attr_name: The name of the attribute to check
:return: True if 'attr_name' exists in all environments
"""
# Default implementation, will not work with things that cannot be pickled:
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
try:
self.get_attr(attr_name)
return True
except AttributeError:
return False

@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""
Expand Down Expand Up @@ -392,6 +407,9 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
def get_images(self) -> Sequence[Optional[np.ndarray]]:
return self.venv.get_images()

def has_attr(self, attr_name: str) -> bool:
return self.venv.has_attr(attr_name)

def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
return self.venv.get_attr(attr_name, indices)

Expand Down
17 changes: 16 additions & 1 deletion stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from stable_baselines3.common.vec_env.patch_gym import _patch_env


def _worker(
def _worker( # noqa: C901
remote: mp.connection.Connection,
parent_remote: mp.connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
Expand Down Expand Up @@ -58,6 +58,12 @@ def _worker(
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(env.get_wrapper_attr(data))
elif cmd == "has_attr":
try:
env.get_wrapper_attr(data)
remote.send(True)
except AttributeError:
remote.send(False)
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
Expand All @@ -66,6 +72,8 @@ def _worker(
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break
except KeyboardInterrupt:
break


class SubprocVecEnv(VecEnv):
Expand Down Expand Up @@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]:
outputs = [pipe.recv() for pipe in self.remotes]
return outputs

def has_attr(self, attr_name: str) -> bool:
"""Check if an attribute exists for a vectorized environment. (see base class)."""
target_remotes = self._get_target_remotes(indices=None)
for remote in target_remotes:
remote.send(("has_attr", attr_name))
return all([remote.recv() for remote in target_remotes])

def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.0
2.6.0a0
16 changes: 16 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,28 @@ def make_env():
# we need a X server to test the "human" mode (uses OpenCV)
# vec_env.render(mode="human")

# Set a new attribute, on the last wrapper and on the env
assert not vec_env.has_attr("dummy")
# Set value for the last wrapper only
vec_env.set_attr("dummy", 12)
assert vec_env.get_attr("dummy") == [12] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].dummy == 12

assert not vec_env.has_attr("dummy2")
# Set the value on the original env
vec_env.env_method("set_wrapper_attr", "dummy2", 2)
assert vec_env.get_attr("dummy2") == [2] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].unwrapped.dummy2 == 2

env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
for env_idx in range(N_ENVS):
setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
# Retrieve the value for each environment
assert vec_env.has_attr("current_step")
getattr_results = vec_env.get_attr("current_step")

assert len(env_method_results) == N_ENVS
Expand Down

0 comments on commit 6f55734

Please sign in to comment.