Doc update: custom envs, IsaacLab, Brax and dm_control #2072

Merged · 3 commits · Jan 26, 2025

18 changes: 18 additions & 0 deletions docs/guide/custom_env.rst
@@ -24,6 +24,24 @@ That is to say, your environment must implement the following methods (and inher…
Under the hood, when a channel-last image is passed, SB3 uses a ``VecTransposeImage`` wrapper to re-order the channels.


.. note::

SB3 doesn't support ``Discrete`` and ``MultiDiscrete`` spaces with ``start!=0``. However, you can update your environment or use a wrapper to make your env compatible with SB3:

.. code-block:: python

import gymnasium as gym

class ShiftWrapper(gym.Wrapper):
    """Allow using ``Discrete`` action spaces with ``start != 0``."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.Discrete)
        # Expose a zero-based action space to SB3
        self.action_space = gym.spaces.Discrete(env.action_space.n, start=0)

    def step(self, action: int):
        # Shift the action back into the range expected by the wrapped env
        return self.env.step(action + self.env.action_space.start)
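
For example, here is a minimal sketch that reuses the ``ShiftWrapper`` above on a hypothetical env whose actions start at 1 (``OffsetActionEnv`` is made up for illustration and is not part of SB3):

.. code-block:: python

import numpy as np
import gymnasium as gym

from stable_baselines3.common.env_checker import check_env

class OffsetActionEnv(gym.Env):
    """Hypothetical env whose ``Discrete`` action space starts at 1 instead of 0."""

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(3, start=1)  # valid actions are 1, 2, 3

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return np.zeros(2, dtype=np.float32), {}

    def step(self, action):
        assert action in self.action_space
        return np.zeros(2, dtype=np.float32), 0.0, False, False, {}

# Wrapping the env shifts actions back to start at 0, so the env checker no longer warns
check_env(ShiftWrapper(OffsetActionEnv()), warn=True)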


.. code-block:: python

45 changes: 22 additions & 23 deletions docs/guide/examples.rst
@@ -735,41 +735,40 @@ A2C policy gradient updates on the model.
print(f"Best fitness: {top_candidates[0][1]:.2f}")


-SB3 and ProcgenEnv
-------------------
+SB3 with Isaac Lab, Brax, Procgen, EnvPool
+------------------------------------------

-Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
-environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow
-to keep track of the agent progress.
+Some massively parallel simulators such as `EnvPool <https://github.com/sail-sg/envpool>`_, `Isaac Lab <https://github.com/isaac-sim/IsaacLab>`_, `Brax <https://github.com/google/brax>`_ or `ProcGen <https://github.com/Farama-Foundation/Procgen2>`_ already produce a vectorized environment to speed up data collection (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_).
+
+To use SB3 with these tools, you need to wrap the env with a tool-specific ``VecEnvWrapper`` that pre-processes the data for SB3;
+you can find links to some of these wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
+
+- Isaac Lab wrapper: `link <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py>`__
+- Brax: `link <https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7>`__
+- EnvPool: `link <https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py>`__

-.. code-block:: python
-
-from procgen import ProcgenEnv
-
-from stable_baselines3 import PPO
-from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
-
-# ProcgenEnv is already vectorized
-venv = ProcgenEnv(num_envs=2, env_name="starpilot")
-
-# To use only part of the observation:
-# venv = VecExtractDictObs(venv, "rgb")
-
-# Wrap with a VecMonitor to collect stats and avoid errors
-venv = VecMonitor(venv=venv)
-
-model = PPO("MultiInputPolicy", venv, verbose=1)
-model.learn(10_000)

-SB3 with EnvPool or Isaac Gym
------------------------------
-
-Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by
-already providing a vectorized implementation.
-
-To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3,
-you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.

+SB3 with DeepMind Control (dm_control)
+--------------------------------------
+
+If you want to use SB3 with `dm_control <https://github.com/google-deepmind/dm_control>`_, you need two wrappers (one from `shimmy <https://github.com/Farama-Foundation/Shimmy>`_, one from Gymnasium) to convert it to a Gymnasium-compatible environment:
+
+.. code-block:: python
+
+import shimmy
+import stable_baselines3 as sb3
+from dm_control import suite
+from gymnasium.wrappers import FlattenObservation
+
+# Available envs:
+# suite._DOMAINS and suite.dog.SUITE
+
+env = suite.load(domain_name="dog", task_name="run")
+gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
+
+model = sb3.PPO("MlpPolicy", gym_env, verbose=1)
+model.learn(10_000, progress_bar=True)
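
The wrappers linked in the new section all follow SB3's ``VecEnvWrapper`` pattern: subclass it and implement ``reset()`` and ``step_wait()`` so the already-batched data comes back in the format SB3 expects. Below is a minimal, self-contained sketch of that pattern; it wraps a plain ``DummyVecEnv`` only so the snippet runs anywhere, whereas a real adapter would convert the simulator's batched (possibly GPU) buffers instead:

.. code-block:: python

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnvWrapper, VecMonitor

class PassthroughAdapter(VecEnvWrapper):
    """Sketch of an adapter: forward already-batched reset/step results to SB3."""

    def reset(self):
        # A real adapter would convert the simulator's buffers to numpy arrays here
        return self.venv.reset()

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        # A real adapter would also set infos[i]["terminal_observation"] when an episode ends
        return obs, rewards, dones, infos

# DummyVecEnv stands in for an already-vectorized simulator in this sketch
vec_env = VecMonitor(PassthroughAdapter(DummyVecEnv([lambda: gym.make("CartPole-v1")] * 4)))
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(5_000)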


Record a Video
1 change: 1 addition & 0 deletions docs/guide/sbx.rst
@@ -18,6 +18,7 @@ Implemented algorithms:
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)
- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)


As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
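
For example, training with SBX looks the same as with SB3 (a minimal sketch, assuming the ``sbx`` package is installed and exports the algorithms listed above):

.. code-block:: python

# Sketch only: SBX mirrors the SB3 API, so the usual SB3 workflow applies
from sbx import SAC

model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000)
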
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
@@ -42,6 +42,9 @@ Documentation:
- Add FootstepNet Envs to the project page (@cgaspard3333)
- Added FRASA to the project page (@MarcDcls)
- Fixed atari example (@chrisgao99)
- Add a note about ``Discrete`` action spaces with ``start!=0``
- Update doc for massively parallel simulators (Isaac Lab, Brax, ...)
- Add dm_control example

Release 2.4.1 (2024-12-20)
--------------------------
3 changes: 2 additions & 1 deletion stable_baselines3/common/env_checker.py
@@ -37,7 +37,8 @@ def _check_non_zero_start(space: spaces.Space, space_type: str = "observation",
warnings.warn(
    f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
    "is not supported by Stable-Baselines3. "
-   f"You can use a wrapper or update your {space_type} space."
+   "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) "
+   f"or update your {space_type} space."
)


2 changes: 1 addition & 1 deletion tests/test_vec_normalize.py
@@ -315,7 +315,7 @@ def test_get_original():
assert not np.array_equal(orig_obs, obs)
assert not np.array_equal(orig_rewards, rewards)
np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
-np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
+np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards, atol=1e-6)


def test_get_original_dict():