diff --git a/.github/unittest/linux_libs/scripts_gym/setup_env.sh b/.github/unittest/linux_libs/scripts_gym/setup_env.sh index 25c493babee..aade606ba16 100755 --- a/.github/unittest/linux_libs/scripts_gym/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_gym/setup_env.sh @@ -10,7 +10,6 @@ set -e this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" apt-get update && apt-get install -y git wget gcc g++ - apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0 apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev diff --git a/.github/unittest/linux_libs/scripts_isaaclab/isaac.sh b/.github/unittest/linux_libs/scripts_isaaclab/isaac.sh new file mode 100755 index 00000000000..1068d75f68c --- /dev/null +++ b/.github/unittest/linux_libs/scripts_isaaclab/isaac.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +set -e +set -v + +#if [[ "${{ github.ref }}" =~ release/* ]]; then +# export RELEASE=1 +# export TORCH_VERSION=stable +#else +export RELEASE=0 +export TORCH_VERSION=nightly +#fi + +set -euo pipefail +export PYTHON_VERSION="3.10" +export CU_VERSION="12.8" +export TAR_OPTIONS="--no-same-owner" +export UPLOAD_CHANNEL="nightly" +export TF_CPP_MIN_LOG_LEVEL=0 +export BATCHED_PIPE_TIMEOUT=60 +export TD_GET_DEFAULTS_TO_NONE=1 +export OMNI_KIT_ACCEPT_EULA=yes + +nvidia-smi + +# Setup +apt-get update && apt-get install -y git wget gcc g++ +apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0 +apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +cd "${root_dir}" + +# install conda +printf "* Installing conda\n" +wget -O miniconda.sh 
"http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh" +bash ./miniconda.sh -b -f -p "${conda_dir}" +eval "$(${conda_dir}/bin/conda shell.bash hook)" + + +conda create --prefix ${env_dir} python=3.10 -y +conda activate ${env_dir} + +# Pin pytorch to 2.5.1 for IsaacLab +conda install pytorch==2.5.1 torchvision==0.20.1 pytorch-cuda=12.4 -c pytorch -c nvidia -y + +conda run -p ${env_dir} pip install --upgrade pip +conda run -p ${env_dir} pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com +conda install conda-forge::"cmake>3.22" -y + +git clone https://github.com/isaac-sim/IsaacLab.git +cd IsaacLab +conda run -p ${env_dir} ./isaaclab.sh --install sb3 +cd ../ + +# install tensordict +if [[ "$RELEASE" == 0 ]]; then + conda install "anaconda::cmake>=3.22" -y + conda run -p ${env_dir} python3 -m pip install "pybind11[global]" + conda run -p ${env_dir} python3 -m pip install git+https://github.com/pytorch/tensordict.git +else + conda run -p ${env_dir} python3 -m pip install tensordict +fi + +# smoke test +conda run -p ${env_dir} python -c "import tensordict" + +printf "* Installing torchrl\n" +conda run -p ${env_dir} python setup.py develop +conda run -p ${env_dir} python -c "import torchrl" + +# Install pytest +conda run -p ${env_dir} python -m pip install pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures pytest-error-for-skips pytest-asyncio + +# Run tests +conda run -p ${env_dir} python -m pytest test/test_libs.py -k isaac -s diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index fe17076c8aa..9cb1a9f1c5c 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -230,6 +230,24 @@ jobs: ./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh ./.github/unittest/linux_libs/scripts_gym/post_process.sh + unittests-isaaclab: + strategy: + matrix: + python_version: ["3.10"] + cuda_arch_version: ["12.8"] + if: ${{ 
github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments/Isaac') }} + uses: vmoens/test-infra/.github/workflows/isaac_linux_job_v2.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvcr.io/nvidia/isaac-lab:2.1.0" + test-infra-repository: vmoens/test-infra + gpu-arch-type: cuda + gpu-arch-version: "12.8" + timeout: 120 + script: | + ./.github/unittest/linux_libs/scripts_isaaclab/isaac.sh + unittests-jumanji: strategy: matrix: diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index e5602d0553f..ec8b29a9abd 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1417,6 +1417,7 @@ the following function will return ``1`` when queried: HabitatEnv IsaacGymEnv IsaacGymWrapper + IsaacLabWrapper JumanjiEnv JumanjiWrapper MeltingpotEnv diff --git a/test/test_libs.py b/test/test_libs.py index ad0820b3e8a..18bae38c12a 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -32,32 +32,6 @@ import pytest import torch -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( - _make_multithreaded_env, - CARTPOLE_VERSIONED, - get_available_devices, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - retry, - rollout_consistency_assertion, - ) -else: - from _utils_internal import ( - _make_multithreaded_env, - CARTPOLE_VERSIONED, - get_available_devices, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - retry, - rollout_consistency_assertion, - ) from packaging import version from tensordict import ( assert_allclose_td, @@ -155,6 +129,33 @@ ValueOperator, ) +if os.getenv("PYTORCH_TEST_FBCODE"): + from pytorch.rl.test._utils_internal import ( + _make_multithreaded_env, + CARTPOLE_VERSIONED, + get_available_devices, + get_default_devices, + HALFCHEETAH_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + 
rand_reset, + retry, + rollout_consistency_assertion, + ) +else: + from _utils_internal import ( + _make_multithreaded_env, + CARTPOLE_VERSIONED, + get_available_devices, + get_default_devices, + HALFCHEETAH_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + rand_reset, + retry, + rollout_consistency_assertion, + ) + _has_d4rl = importlib.util.find_spec("d4rl") is not None _has_mo = importlib.util.find_spec("mo_gymnasium") is not None @@ -166,6 +167,9 @@ _has_minari = importlib.util.find_spec("minari") is not None _has_gymnasium = importlib.util.find_spec("gymnasium") is not None + +_has_isaaclab = importlib.util.find_spec("isaaclab") is not None + _has_gym_regular = importlib.util.find_spec("gym") is not None if _has_gymnasium: set_gym_backend("gymnasium").set() @@ -4541,6 +4545,65 @@ def test_render(self, rollout_steps): assert not torch.equal(rollout_penultimate_image, image_from_env) +@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found") +class TestIsaacLab: + @pytest.fixture(scope="class") + def env(self): + torch.manual_seed(0) + import argparse + + # This code block ensures that the Isaac app is started in headless mode + from isaaclab.app import AppLauncher + + parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.") + AppLauncher.add_app_launcher_args(parser) + args_cli, hydra_args = parser.parse_known_args(["--headless"]) + AppLauncher(args_cli) + + # Imports and env + import gymnasium as gym + import isaaclab_tasks # noqa: F401 + from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg + from torchrl.envs.libs.isaac_lab import IsaacLabWrapper + + torchrl_logger.info("Making IsaacLab env...") + env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg()) + torchrl_logger.info("Wrapping IsaacLab env...") + try: + env = IsaacLabWrapper(env) + yield env + finally: + torchrl_logger.info("Closing IsaacLab env...") + env.close() + torchrl_logger.info("Closed") + + def test_isaaclab(self, env): + assert env.batch_size 
== (4096,) + assert env._is_batched + torchrl_logger.info("Checking env specs...") + env.check_env_specs(break_when_any_done="both") + torchrl_logger.info("Check succeeded!") + + def test_isaac_collector(self, env): + col = SyncDataCollector( + env, env.rand_action, frames_per_batch=1000, total_frames=100_000_000 + ) + try: + for data in col: + assert data.shape == (4096, 1) + break + finally: + # Shut down explicitly without closing the env; otherwise `__del__` calls `shutdown` and the next test will fail + col.shutdown(close_env=False) + + def test_isaaclab_reset(self, env): + # Make a rollout that will stop as soon as a trajectory reaches a done state + r = env.rollout(1_000_000) + + # Check that observations at done steps are non-finite (the NaN placeholder), not real values + assert not r["next", "policy"][r["next", "done"].squeeze(-1)].isfinite().any() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 8a69b8cee09..a00f902a716 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -278,15 +278,19 @@ def pause(self): f"Collector pause() is not implemented for {type(self).__name__}." ) - def async_shutdown(self, timeout: float | None = None) -> None: + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: """Shuts down the collector when started asynchronously with the `start` method. Arg: timeout (float, optional): The maximum time to wait for the collector to shutdown. + close_env (bool, optional): If True, the collector will close the contained environment. + Defaults to `True`. .. 
seealso:: :meth:`~.start` """ - return self.shutdown(timeout=timeout) + return self.shutdown(timeout=timeout, close_env=close_env) def update_policy_weights_( self, @@ -342,7 +346,7 @@ def next(self): return None @abc.abstractmethod - def shutdown(self, timeout: float | None = None) -> None: + def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: raise NotImplementedError @abc.abstractmethod @@ -1317,12 +1321,14 @@ def _run_iterator(self): if self._stop: return - def async_shutdown(self, timeout: float | None = None) -> None: + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: """Finishes processes started by ray.init() during async execution.""" self._stop = True if hasattr(self, "_thread") and self._thread.is_alive(): self._thread.join(timeout=timeout) - self.shutdown() + self.shutdown(close_env=close_env) def _postproc(self, tensordict_out): if self.split_trajs: @@ -1582,14 +1588,20 @@ def reset(self, index=None, **kwargs) -> None: ) self._shuttle["collector"] = collector_metadata - def shutdown(self, timeout: float | None = None) -> None: - """Shuts down all workers and/or closes the local environment.""" + def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + """Shuts down all workers and/or closes the local environment. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + No effect for this class. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + """ if not self.closed: self.closed = True del self._shuttle if self._use_buffers: del self._final_rollout - if not self.env.is_closed: + if close_env and not self.env.is_closed: self.env.close() del self.env return @@ -2391,8 +2403,17 @@ def __del__(self): # __del__ will not affect the program. pass - def shutdown(self, timeout: float | None = None) -> None: - """Shuts down all processes. 
This operation is irreversible.""" + def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + """Shuts down all processes. This operation is irreversible. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + """ + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) self._shutdown_main(timeout) def _shutdown_main(self, timeout: float | None = None) -> None: @@ -2665,7 +2686,11 @@ def next(self): return super().next() # for RPC - def shutdown(self, timeout: float | None = None) -> None: + def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) if hasattr(self, "out_buffer"): del self.out_buffer if hasattr(self, "buffers"): @@ -3038,9 +3063,13 @@ def next(self): return super().next() # for RPC - def shutdown(self, timeout: float | None = None) -> None: + def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: if hasattr(self, "out_tensordicts"): del self.out_tensordicts + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
+ ) return super().shutdown(timeout=timeout) # for RPC @@ -3382,8 +3411,8 @@ def next(self): return super().next() # for RPC - def shutdown(self, timeout: float | None = None) -> None: - return super().shutdown(timeout=timeout) + def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + return super().shutdown(timeout=timeout, close_env=close_env) # for RPC def set_seed(self, seed: int, static_seed: bool = False) -> int: diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index de6c48c1402..b4014462a23 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -20,6 +20,7 @@ HabitatEnv, IsaacGymEnv, IsaacGymWrapper, + IsaacLabWrapper, JumanjiEnv, JumanjiWrapper, MeltingpotEnv, @@ -131,6 +132,7 @@ "ActionDiscretizer", "ActionMask", "VecNormV2", + "IsaacLabWrapper", "AutoResetEnv", "AutoResetTransform", "AsyncEnvPool", diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index efc761ae247..468debb1e71 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -519,6 +519,12 @@ def validated(self, value): def _reset( self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: + if ( + tensordict is not None + and "_reset" in tensordict + and not tensordict["_reset"].all() + ): + raise RuntimeError("Partial resets are not handled at this level.") obs, info = self._reset_output_transform(self._env.reset(**kwargs)) source = self.read_obs(obs) diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 8ae4695683c..908758876c8 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -16,6 +16,7 @@ set_gym_backend, ) from .habitat import HabitatEnv +from .isaac_lab import IsaacLabWrapper from .isaacgym import IsaacGymEnv, IsaacGymWrapper from .jumanji import JumanjiEnv, JumanjiWrapper from .meltingpot import MeltingpotEnv, MeltingpotWrapper @@ -32,22 +33,20 @@ "BraxWrapper", "DMControlEnv", "DMControlWrapper", - "MultiThreadedEnv", - 
"MultiThreadedEnvWrapper", - "gym_backend", "GymEnv", "GymWrapper", - "MOGymEnv", - "MOGymWrapper", - "register_gym_spec_conversion", - "set_gym_backend", "HabitatEnv", "IsaacGymEnv", "IsaacGymWrapper", + "IsaacLabWrapper", "JumanjiEnv", "JumanjiWrapper", + "MOGymEnv", + "MOGymWrapper", "MeltingpotEnv", "MeltingpotWrapper", + "MultiThreadedEnv", + "MultiThreadedEnvWrapper", "OpenMLEnv", "OpenSpielEnv", "OpenSpielWrapper", @@ -60,4 +59,7 @@ "UnityMLAgentsWrapper", "VmasEnv", "VmasWrapper", + "gym_backend", + "register_gym_spec_conversion", + "set_gym_backend", ] diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 6cab799d515..5c4defbc52d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -53,6 +53,7 @@ _has_mo = importlib.util.find_spec("mo_gymnasium") is not None _has_sb3 = importlib.util.find_spec("stable_baselines3") is not None +_has_isaaclab = importlib.util.find_spec("isaaclab") is not None _has_minigrid = importlib.util.find_spec("minigrid") is not None @@ -793,6 +794,7 @@ class PixelObservationWrapper: class _GymAsyncMeta(_EnvPostInit): def __call__(cls, *args, **kwargs): + missing_obs_value = kwargs.pop("missing_obs_value", None) instance: GymWrapper = super().__call__(*args, **kwargs) # before gym 0.22, there was no final_observation @@ -803,6 +805,15 @@ def __call__(cls, *args, **kwargs): VecGymEnvTransform, ) + if _has_isaaclab: + from isaaclab.envs import ManagerBasedRLEnv + + kwargs = {} + if missing_obs_value is not None: + kwargs["missing_obs_value"] = missing_obs_value + if isinstance(instance._env.unwrapped, ManagerBasedRLEnv): + return TransformedEnv(instance, VecGymEnvTransform(**kwargs)) + if _has_sb3: from stable_baselines3.common.vec_env.base_vec_env import VecEnv @@ -845,7 +856,10 @@ def __call__(cls, *args, **kwargs): instance.observation_spec, backend=backend ) ) - return TransformedEnv(instance, VecGymEnvTransform()) + kwargs = {} + if missing_obs_value is not None: + kwargs["missing_obs_value"] = 
missing_obs_value + return TransformedEnv(instance, VecGymEnvTransform(**kwargs)) return instance @@ -892,6 +906,10 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta): env step function. Set this to ``False`` if the environment is evaluated on GPU, such as IsaacLab. Defaults to ``True``. + missing_obs_value (Any, optional): default value to use as placeholder for missing observations, when + the environment is auto-resetting and missing observations cannot be found in the info dictionary + (e.g., with IsaacLab). This argument is passed to :class:`~torchrl.envs.VecGymEnvTransform` by + the metaclass. Attributes: available_envs (List[str]): a list of environments to build. @@ -1069,14 +1087,17 @@ def _post_init(self): @property def _is_batched(self): + tuple_of_classes = () if _has_sb3: from stable_baselines3.common.vec_env.base_vec_env import VecEnv - tuple_of_classes = (VecEnv,) - else: - tuple_of_classes = () + tuple_of_classes = tuple_of_classes + (VecEnv,) + if _has_isaaclab: + from isaaclab.envs import ManagerBasedRLEnv + + tuple_of_classes = tuple_of_classes + (ManagerBasedRLEnv,) return isinstance( - self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,) + self._env.unwrapped, tuple_of_classes + (gym_backend("vector").VectorEnv,) ) @implement_for("gym") @@ -1562,7 +1583,10 @@ def _replace_reset(self, reset, kwargs): # noqa def _replace_reset(self, reset, kwargs): # noqa import gymnasium as gym - if self._env.autoreset_mode == gym.vector.AutoresetMode.DISABLED: + if ( + getattr(self._env, "autoreset_mode", None) + == gym.vector.AutoresetMode.DISABLED + ): options = {"reset_mask": reset.view(self.batch_size).numpy()} kwargs.setdefault("options", {}).update(options) return kwargs diff --git a/torchrl/envs/libs/isaac_lab.py b/torchrl/envs/libs/isaac_lab.py new file mode 100644 index 00000000000..ee6beaefa58 --- /dev/null +++ b/torchrl/envs/libs/isaac_lab.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from torchrl.envs.libs.gym import GymWrapper + + +class IsaacLabWrapper(GymWrapper): + """A wrapper for IsaacLab environments. + + Args: + env (isaaclab.envs.ManagerBasedRLEnv or equivalent): the environment instance to wrap. + categorical_action_encoding (bool, optional): if ``True``, categorical + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). + Defaults to ``False``. + allow_done_after_reset (bool, optional): if ``True``, it is tolerated + for envs to be ``done`` just after :meth:`reset` is called. + Defaults to ``True``. + + For other arguments, see the :class:`torchrl.envs.GymWrapper` documentation. + + Refer to `the Isaac Lab doc for installation instructions `_. + + Example: + >>> # This code block ensures that the Isaac app is started in headless mode + >>> from isaaclab.app import AppLauncher + >>> import argparse + + >>> parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.") + >>> AppLauncher.add_app_launcher_args(parser) + >>> args_cli, hydra_args = parser.parse_known_args(["--headless"]) + >>> app_launcher = AppLauncher(args_cli) + + >>> # Imports and env + >>> import gymnasium as gym + >>> import isaaclab_tasks # noqa: F401 + >>> from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg + >>> from torchrl.envs.libs.isaac_lab import IsaacLabWrapper + + >>> env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg()) + >>> env = IsaacLabWrapper(env) + + """ + + def __init__( + self, + env: isaaclab.envs.ManagerBasedRLEnv, # noqa: F821 + *, + categorical_action_encoding: bool = False, + allow_done_after_reset: bool = True, + convert_actions_to_numpy: bool = False, + device: torch.device | None = None, + **kwargs, + ): 
+ if device is None: + device = torch.device("cuda:0") + super().__init__( + env, + device=device, + categorical_action_encoding=categorical_action_encoding, + allow_done_after_reset=allow_done_after_reset, + convert_actions_to_numpy=convert_actions_to_numpy, + **kwargs, + ) + + def seed(self, seed: int | None): + self._set_seed(seed) + + def _output_transform(self, step_outputs_tuple): # noqa: F811 + # IsaacLab will modify the `terminated` and `truncated` tensors + # in-place. We clone them here to make sure data doesn't inadvertently get modified. + # The variable naming follows torchrl's convention here. + observations, reward, terminated, truncated, info = step_outputs_tuple + done = terminated | truncated + reward = reward.unsqueeze(-1) # to get to (num_envs, 1) + return ( + observations, + reward, + terminated.clone(), + truncated.clone(), + done.clone(), + info, + ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0a20935fc9e..19e2ad7ec7d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8868,16 +8868,21 @@ class VecGymEnvTransform(Transform): Args: final_name (str, optional): the name of the final observation in the dict. Defaults to `"final"`. + missing_obs_value (Any, optional): default value to use as placeholder for missing + last observations. Defaults to `np.nan`. .. note:: In general, this class should not be handled directly. It is created whenever a vectorized environment is placed within a :class:`GymWrapper`. 
""" - def __init__(self, final_name="final"): + def __init__(self, final_name: str = "final", missing_obs_value: Any = np.nan): self.final_name = final_name super().__init__() self._memo = {} + if not isinstance(missing_obs_value, torch.Tensor): + missing_obs_value = torch.tensor(missing_obs_value) + self.missing_obs_value = missing_obs_value def set_container(self, container: Transform | EnvBase) -> None: out = super().set_container(container) @@ -8908,7 +8913,7 @@ def _step( else: saved_next = next_tensordict.select(*self.obs_keys).clone() for obs_key in self.obs_keys: - next_tensordict[obs_key][done] = torch.tensor(np.nan) + next_tensordict[obs_key][done] = self.missing_obs_value self._memo["saved_next"] = saved_next else: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 02bb4db7a40..c25939a4f62 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -14,7 +14,7 @@ import re import warnings from enum import Enum -from typing import Any +from typing import Any, Literal import torch @@ -687,7 +687,7 @@ def check_env_specs( check_dtype=True, seed: int | None = None, tensordict: TensorDictBase | None = None, - break_when_any_done: bool | str = None, + break_when_any_done: bool | Literal["both"] = None, ): """Tests an environment specs against the results of short rollout.