[Feature] IsaacLab wrapper #2937

Merged: 63 commits, May 13, 2025.
Commits
ee4865a Update (vmoens, May 2, 2025)
836c854 Update (vmoens, May 6, 2025)
321157f Update (vmoens, May 6, 2025)
650d34e Update (vmoens, May 6, 2025)
45f1bf2 Update (vmoens, May 6, 2025)
284cffd Update (vmoens, May 6, 2025)
2e5de6d Update (vmoens, May 6, 2025)
6e8d486 Update (vmoens, May 6, 2025)
1631ba4 Update (vmoens, May 6, 2025)
bc423b6 Update (vmoens, May 6, 2025)
b594d62 Update (vmoens, May 6, 2025)
56d58df Update (vmoens, May 6, 2025)
e34fc05 Update (vmoens, May 6, 2025)
8008d88 Update (vmoens, May 6, 2025)
c59802e Update (vmoens, May 6, 2025)
6d4d7ea Update (vmoens, May 6, 2025)
828c1bd Update (vmoens, May 6, 2025)
f103e8c Update (vmoens, May 6, 2025)
30e244c Update (vmoens, May 6, 2025)
29bc3ac Update (vmoens, May 6, 2025)
524c342 Update (vmoens, May 6, 2025)
67f3b8a Update (vmoens, May 6, 2025)
2f4544d Update (vmoens, May 6, 2025)
788a4cd Update (vmoens, May 6, 2025)
b5ec158 Update (vmoens, May 7, 2025)
a378b9b Update (vmoens, May 7, 2025)
3d197a8 Update (vmoens, May 7, 2025)
22192ac Update (vmoens, May 7, 2025)
3640654 Update (vmoens, May 7, 2025)
962ddfc Update (vmoens, May 7, 2025)
2b2625b Update (vmoens, May 7, 2025)
dcbbb4c Update (vmoens, May 7, 2025)
295a828 Update (vmoens, May 7, 2025)
98d480e Update (vmoens, May 7, 2025)
58ae14f Update (vmoens, May 7, 2025)
2abaa41 Update (vmoens, May 9, 2025)
0699be9 Update (vmoens, May 9, 2025)
f5aa260 Update (vmoens, May 9, 2025)
04c8ef9 Update (vmoens, May 9, 2025)
cfdb98b Update (vmoens, May 9, 2025)
4016986 Update (vmoens, May 9, 2025)
fe5403b Update (vmoens, May 9, 2025)
81095b3 Update (vmoens, May 9, 2025)
44189fb Update (vmoens, May 9, 2025)
2bed777 Update (vmoens, May 9, 2025)
ec1d009 Update (vmoens, May 9, 2025)
ce4f31b Update (vmoens, May 9, 2025)
9fa150d Update (vmoens, May 9, 2025)
a365c4c Update (vmoens, May 12, 2025)
1132aed Update (vmoens, May 12, 2025)
6c45d31 Update (vmoens, May 12, 2025)
6324de6 Update (vmoens, May 13, 2025)
e5c7b46 Update (vmoens, May 13, 2025)
00ce1b4 Update (vmoens, May 13, 2025)
3f61daa Update (vmoens, May 13, 2025)
124729b Update (vmoens, May 13, 2025)
fc79624 Update (vmoens, May 13, 2025)
ed64649 Update (vmoens, May 13, 2025)
9c41be3 Update (vmoens, May 13, 2025)
a09eaa8 Update (vmoens, May 13, 2025)
c694d24 Update (vmoens, May 13, 2025)
a4f0d3b Update (vmoens, May 13, 2025)
7e8e498 Update (vmoens, May 13, 2025)
1 change: 0 additions & 1 deletion .github/unittest/linux_libs/scripts_gym/setup_env.sh
@@ -10,7 +10,6 @@ set -e
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository"
apt-get update && apt-get install -y git wget gcc g++

apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev

81 changes: 81 additions & 0 deletions .github/unittest/linux_libs/scripts_isaaclab/isaac.sh
@@ -0,0 +1,81 @@
#!/usr/bin/env bash

set -e
set -v

#if [[ "${{ github.ref }}" =~ release/* ]]; then
# export RELEASE=1
# export TORCH_VERSION=stable
#else
export RELEASE=0
export TORCH_VERSION=nightly
#fi

set -euo pipefail
export PYTHON_VERSION="3.10"
export CU_VERSION="12.8"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
export TD_GET_DEFAULTS_TO_NONE=1
export OMNI_KIT_ACCEPT_EULA=yes

nvidia-smi

# Setup
apt-get update && apt-get install -y git wget gcc g++
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev

git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

cd "${root_dir}"

# install conda
printf "* Installing conda\n"
wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh"
bash ./miniconda.sh -b -f -p "${conda_dir}"
eval "$(${conda_dir}/bin/conda shell.bash hook)"


conda create --prefix ${env_dir} python=3.10 -y
conda activate ${env_dir}

# Pin pytorch to 2.5.1 for IsaacLab
conda install pytorch==2.5.1 torchvision==0.20.1 pytorch-cuda=12.4 -c pytorch -c nvidia -y

conda run -p ${env_dir} pip install --upgrade pip
conda run -p ${env_dir} pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com
conda install conda-forge::"cmake>3.22" -y

git clone https://github.com/isaac-sim/IsaacLab.git
cd IsaacLab
conda run -p ${env_dir} ./isaaclab.sh --install sb3
cd ../

# install tensordict
if [[ "$RELEASE" == 0 ]]; then
conda install "anaconda::cmake>=3.22" -y
conda run -p ${env_dir} python3 -m pip install "pybind11[global]"
conda run -p ${env_dir} python3 -m pip install git+https://github.com/pytorch/tensordict.git
else
conda run -p ${env_dir} python3 -m pip install tensordict
fi

# smoke test
conda run -p ${env_dir} python -c "import tensordict"

printf "* Installing torchrl\n"
conda run -p ${env_dir} python setup.py develop
conda run -p ${env_dir} python -c "import torchrl"

# Install pytest
conda run -p ${env_dir} python -m pip install pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures pytest-error-for-skips pytest-asyncio

# Run tests
conda run -p ${env_dir} python -m pytest test/test_libs.py -k isaac -s
18 changes: 18 additions & 0 deletions .github/workflows/test-linux-libs.yml
@@ -230,6 +230,24 @@ jobs:
./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
./.github/unittest/linux_libs/scripts_gym/post_process.sh

unittests-isaaclab:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.8"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments/Isaac') }}
uses: vmoens/test-infra/.github/workflows/isaac_linux_job_v2.yml@main
with:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
docker-image: "nvcr.io/nvidia/isaac-lab:2.1.0"
test-infra-repository: vmoens/test-infra
gpu-arch-type: cuda
gpu-arch-version: "12.8"
timeout: 120
script: |
./.github/unittest/linux_libs/scripts_isaaclab/isaac.sh

unittests-jumanji:
strategy:
matrix:
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -1417,6 +1417,7 @@ the following function will return ``1`` when queried:
HabitatEnv
IsaacGymEnv
IsaacGymWrapper
IsaacLabWrapper
JumanjiEnv
JumanjiWrapper
MeltingpotEnv
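For reference, here is a minimal usage sketch of the new IsaacLabWrapper, assembled from the test fixture added further down in this PR. The task id Isaac-Ant-v0, the AntEnvCfg config and the headless AppLauncher boilerplate are taken from that test; the short rollout at the end is purely illustrative.

# Sketch based on TestIsaacLab.env below; the Isaac app must be launched
# (here headless) before any isaaclab_tasks import.
import argparse

from isaaclab.app import AppLauncher

parser = argparse.ArgumentParser()
AppLauncher.add_app_launcher_args(parser)
args_cli, _ = parser.parse_known_args(["--headless"])
AppLauncher(args_cli)

import gymnasium as gym
import isaaclab_tasks  # noqa: F401  (registers the Isaac-* task ids)
from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg

from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

env = IsaacLabWrapper(gym.make("Isaac-Ant-v0", cfg=AntEnvCfg()))
print(env.batch_size)      # torch.Size([4096]) for this task, per the test below
rollout = env.rollout(3)   # batched rollout across all vectorized sub-envs
env.close()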
115 changes: 89 additions & 26 deletions test/test_libs.py
@@ -32,32 +32,6 @@
import pytest
import torch

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
else:
from _utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
from packaging import version
from tensordict import (
assert_allclose_td,
@@ -155,6 +129,33 @@
ValueOperator,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
else:
from _utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)

_has_d4rl = importlib.util.find_spec("d4rl") is not None

_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
@@ -166,6 +167,9 @@
_has_minari = importlib.util.find_spec("minari") is not None

_has_gymnasium = importlib.util.find_spec("gymnasium") is not None

_has_isaaclab = importlib.util.find_spec("isaaclab") is not None

_has_gym_regular = importlib.util.find_spec("gym") is not None
if _has_gymnasium:
set_gym_backend("gymnasium").set()
@@ -4541,6 +4545,65 @@ def test_render(self, rollout_steps):
assert not torch.equal(rollout_penultimate_image, image_from_env)


@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found")
class TestIsaacLab:
@pytest.fixture(scope="class")
def env(self):
torch.manual_seed(0)
import argparse

# This code block ensures that the Isaac app is started in headless mode
from isaaclab.app import AppLauncher

parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
AppLauncher.add_app_launcher_args(parser)
args_cli, hydra_args = parser.parse_known_args(["--headless"])
AppLauncher(args_cli)

# Imports and env
import gymnasium as gym
import isaaclab_tasks # noqa: F401
from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

torchrl_logger.info("Making IsaacLab env...")
env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
torchrl_logger.info("Wrapping IsaacLab env...")
try:
env = IsaacLabWrapper(env)
yield env
finally:
torchrl_logger.info("Closing IsaacLab env...")
env.close()
torchrl_logger.info("Closed")

def test_isaaclab(self, env):
assert env.batch_size == (4096,)
assert env._is_batched
torchrl_logger.info("Checking env specs...")
env.check_env_specs(break_when_any_done="both")
torchrl_logger.info("Check succeeded!")

def test_isaac_collector(self, env):
col = SyncDataCollector(
env, env.rand_action, frames_per_batch=1000, total_frames=100_000_000
)
try:
for data in col:
assert data.shape == (4096, 1)
break
finally:
# We must do that, otherwise `__del__` calls `shutdown` and the next test will fail
col.shutdown(close_env=False)

def test_isaaclab_reset(self, env):
# Make a rollout that will stop as soon as a trajectory reaches a done state
r = env.rollout(1_000_000)

# Check that done obs are non-finite (the final obs of a trajectory is not available)
assert not r["next", "policy"][r["next", "done"].squeeze(-1)].isfinite().any()
vmoens (Contributor, Author) commented on this line:

@lin-erica this should hold for you too, I don't think we have access to the last obs of a trajectory in Isaac

Reply (Contributor):

Yes, that's right.


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
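Regarding the exchange above: because Isaac does not expose the final observation of a trajectory, the terminal ("next", "policy") entries come back non-finite, which is exactly what test_isaaclab_reset asserts. A small sketch of how downstream code might handle this (illustrative only, not part of the PR):

# Illustrative: drop terminal next-observations, which are non-finite
# because IsaacLab does not return the last obs of a trajectory.
rollout = env.rollout(1_000)
done = rollout["next", "done"].squeeze(-1)           # boolean mask over (batch, time)
usable_next_obs = rollout["next", "policy"][~done]   # non-terminal entries only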
57 changes: 43 additions & 14 deletions torchrl/collectors/collectors.py
@@ -278,15 +278,19 @@ def pause(self):
f"Collector pause() is not implemented for {type(self).__name__}."
)

def async_shutdown(self, timeout: float | None = None) -> None:
def async_shutdown(
self, timeout: float | None = None, close_env: bool = True
) -> None:
"""Shuts down the collector when started asynchronously with the `start` method.

Args:
timeout (float, optional): The maximum time to wait for the collector to shutdown.
close_env (bool, optional): If True, the collector will close the contained environment.
Defaults to `True`.

.. seealso:: :meth:`~.start`
"""
return self.shutdown(timeout=timeout)
return self.shutdown(timeout=timeout, close_env=close_env)

def update_policy_weights_(
self,
@@ -342,7 +346,7 @@ def next(self):
return None

@abc.abstractmethod
def shutdown(self, timeout: float | None = None) -> None:
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
raise NotImplementedError

@abc.abstractmethod
@@ -1317,12 +1321,14 @@ def _run_iterator(self):
if self._stop:
return

def async_shutdown(self, timeout: float | None = None) -> None:
def async_shutdown(
self, timeout: float | None = None, close_env: bool = True
) -> None:
"""Finishes processes started by ray.init() during async execution."""
self._stop = True
if hasattr(self, "_thread") and self._thread.is_alive():
self._thread.join(timeout=timeout)
self.shutdown()
self.shutdown(close_env=close_env)

def _postproc(self, tensordict_out):
if self.split_trajs:
@@ -1582,14 +1588,20 @@ def reset(self, index=None, **kwargs) -> None:
)
self._shuttle["collector"] = collector_metadata

def shutdown(self, timeout: float | None = None) -> None:
"""Shuts down all workers and/or closes the local environment."""
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
"""Shuts down all workers and/or closes the local environment.

Args:
timeout (float, optional): The timeout for closing pipes between workers.
No effect for this class.
close_env (bool, optional): Whether to close the environment. Defaults to `True`.
"""
if not self.closed:
self.closed = True
del self._shuttle
if self._use_buffers:
del self._final_rollout
if not self.env.is_closed:
if close_env and not self.env.is_closed:
self.env.close()
del self.env
return
@@ -2391,8 +2403,17 @@ def __del__(self):
# __del__ will not affect the program.
pass

def shutdown(self, timeout: float | None = None) -> None:
"""Shuts down all processes. This operation is irreversible."""
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
"""Shuts down all processes. This operation is irreversible.

Args:
timeout (float, optional): The timeout for closing pipes between workers.
close_env (bool, optional): Whether to close the environment. Defaults to `True`.
"""
if not close_env:
raise RuntimeError(
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
)
self._shutdown_main(timeout)

def _shutdown_main(self, timeout: float | None = None) -> None:
@@ -2665,7 +2686,11 @@ def next(self):
return super().next()

# for RPC
def shutdown(self, timeout: float | None = None) -> None:
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
if not close_env:
raise RuntimeError(
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
)
if hasattr(self, "out_buffer"):
del self.out_buffer
if hasattr(self, "buffers"):
@@ -3038,9 +3063,13 @@ def next(self):
return super().next()

# for RPC
def shutdown(self, timeout: float | None = None) -> None:
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
if hasattr(self, "out_tensordicts"):
del self.out_tensordicts
if not close_env:
raise RuntimeError(
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
)
return super().shutdown(timeout=timeout)

# for RPC
@@ -3382,8 +3411,8 @@ def next(self):
return super().next()

# for RPC
def shutdown(self, timeout: float | None = None) -> None:
return super().shutdown(timeout=timeout)
def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
return super().shutdown(timeout=timeout, close_env=close_env)

# for RPC
def set_seed(self, seed: int, static_seed: bool = False) -> int:
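Putting the collector changes together with TestIsaacLab.test_isaac_collector above, the intended usage of the new close_env flag looks roughly like this. This is a sketch under the assumption that the wrapped simulator (e.g. an IsaacLab env) must outlive the collector because it cannot be restarted once closed; the frame counts are arbitrary.

from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env, env.rand_action, frames_per_batch=1000, total_frames=1_000_000
)
for batch in collector:
    ...  # consume the batch
    break

# Release the collector without closing the environment, so the same
# simulator instance stays usable (e.g. for the next test or rollout).
collector.shutdown(close_env=False)
env.rollout(10)  # still valid: the env was not closed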