diff --git a/test/test_libs.py b/test/test_libs.py index 2861e24c3f6..7ddb0d4fc02 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3396,10 +3396,11 @@ class TestRoboHive: # In the CI, robohive should not coexist with other libs so that's fine. # Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT @pytest.mark.parametrize("from_pixels", [False, True]) + @pytest.mark.parametrize("from_depths", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) - def test_robohive(self, envname, from_pixels): + def test_robohive(self, envname, from_pixels, from_depths): with set_gym_backend("gymnasium"): - torchrl_logger.info(f"{envname}-{from_pixels}") + torchrl_logger.info(f"{envname}-{from_pixels}-{from_depths}") if any( substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") ): @@ -3415,7 +3416,7 @@ def test_robohive(self, envname, from_pixels): torchrl_logger.info("no camera") return try: - env = RoboHiveEnv(envname, from_pixels=from_pixels) + env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths) except AttributeError as err: if "'MjData' object has no attribute 'get_body_xipos'" in str(err): torchrl_logger.info("tcdm are broken") diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index b25e9ebb63d..5e5c8f52393 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -95,6 +95,10 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild): be returned (by default under the ``"pixels"`` entry in the output tensordict). If ``False``, observations (eg, states) and pixels will be returned whenever ``from_pixels=True``. Defaults to ``True``. + from_depths (bool, optional): if ``True``, an attempt to return the depth + observations from the env will be performed. By default, these observations + will be written under the ``"depths"`` entry. Requires ``from_pixels`` to be ``True``. + Defaults to ``False``. frame_skip (int, optional): if provided, indicates for how many steps the same action is to be repeated. The observation returned will be the last observation of the sequence, whereas the reward will be the sum @@ -155,6 +159,7 @@ def _build_env( # noqa: F811 env_name: str, from_pixels: bool = False, pixels_only: bool = False, + from_depths: bool = False, **kwargs, ) -> "gym.core.Env": # noqa: F821 if from_pixels: @@ -168,7 +173,9 @@ def _build_env( # noqa: F811 ) kwargs["cameras"] = self.get_available_cams(env_name) cams = list(kwargs.pop("cameras")) - env_name = self.register_visual_env(cams=cams, env_name=env_name) + env_name = self.register_visual_env( + cams=cams, env_name=env_name, from_depths=from_depths + ) elif "cameras" in kwargs and kwargs["cameras"]: raise RuntimeError("Got a list of cameras but from_pixels is set to False.") @@ -194,10 +201,6 @@ def _build_env( # noqa: F811 **kwargs, ) self.wrapper_frame_skip = 1 - if env.visual_keys: - from_pixels = bool(len(env.visual_keys)) - else: - from_pixels = False except TypeError as err: if "unexpected keyword argument 'frameskip" not in str(err): raise err @@ -209,6 +212,7 @@ def _build_env( # noqa: F811 # except Exception as err: # raise RuntimeError(f"Failed to build env {env_name}.") from err self.from_pixels = from_pixels + self.from_depths = from_depths self.render_device = render_device if kwargs.get("read_info", True): self.set_info_dict_reader(self.read_info) @@ -224,7 +228,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 return out @classmethod - def register_visual_env(cls, env_name, cams): + def register_visual_env(cls, env_name, cams, from_depths): with set_directory(cls.CURR_DIR): from robohive.envs.env_variants import register_env_variant @@ -233,9 +237,9 @@ def register_visual_env(cls, env_name, cams): cams = sorted(cams) cams_rep = [i.replace("A:", "A_") for i in cams] new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name]) - if new_env_name in cls.env_list: - return new_env_name visual_keys = [f"rgb:{c}:224x224:2d" for c in cams] + if from_depths: + visual_keys.extend([f"d:{c}:224x224:2d" for c in cams]) register_env_variant( env_name, variants={ @@ -262,13 +266,17 @@ def get_obs(): if self.from_pixels: visual = self.env.get_exteroception() obs_dict.update(visual) - pixel_list = [] + pixel_list, depth_list = [], [] for obs_key in obs_dict: if obs_key.startswith("rgb"): pix = obs_dict[obs_key] if not pix.shape[0] == 1: pix = pix[None] pixel_list.append(pix) + elif obs_key.startswith("d:"): + dep = obs_dict[obs_key] + dep = dep[None] + depth_list.append(dep) elif obs_key in env.obs_keys: value = env.obs_dict[obs_key] if not value.shape: @@ -276,6 +284,8 @@ def get_obs(): _dict[obs_key] = value if pixel_list: _dict["pixels"] = np.concatenate(pixel_list, 0) + if depth_list: + _dict["depths"] = np.concatenate(depth_list, 0) return _dict for i in range(3): @@ -335,7 +345,7 @@ def read_obs(self, observation): pass # recover vec obsdict = {} - pixel_list = [] + pixel_list, depth_list = [], [] if self.from_pixels: visual = self.env.get_exteroception() observations.update(visual) @@ -345,6 +355,10 @@ def read_obs(self, observation): if not pix.shape[0] == 1: pix = pix[None] pixel_list.append(pix) + elif key.startswith("d:"): + dep = observations[key] + dep = dep[None] + depth_list.append(dep) elif key in self._env.obs_keys: value = observations[key] if not value.shape: @@ -354,6 +368,8 @@ def read_obs(self, observation): # obsvec = np.concatenate(obsvec, 0) if self.from_pixels: obsdict.update({"pixels": np.concatenate(pixel_list, 0)}) + if self.from_pixels and self.from_depths: + obsdict.update({"depths": np.concatenate(depth_list, 0)}) out = obsdict return super().read_obs(out)