diff --git a/test/test_libs.py b/test/test_libs.py index 2861e24c3f6..7ddb0d4fc02 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3396,10 +3396,11 @@ class TestRoboHive: # In the CI, robohive should not coexist with other libs so that's fine. # Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT @pytest.mark.parametrize("from_pixels", [False, True]) + @pytest.mark.parametrize("from_depths", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) - def test_robohive(self, envname, from_pixels): + def test_robohive(self, envname, from_pixels, from_depths): with set_gym_backend("gymnasium"): - torchrl_logger.info(f"{envname}-{from_pixels}") + torchrl_logger.info(f"{envname}-{from_pixels}-{from_depths}") if any( substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") ): @@ -3415,7 +3416,7 @@ def test_robohive(self, envname, from_pixels): torchrl_logger.info("no camera") return try: - env = RoboHiveEnv(envname, from_pixels=from_pixels) + env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths) except AttributeError as err: if "'MjData' object has no attribute 'get_body_xipos'" in str(err): torchrl_logger.info("tcdm are broken") diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 30132b931c7..5e5c8f52393 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -237,8 +237,6 @@ def register_visual_env(cls, env_name, cams, from_depths): cams = sorted(cams) cams_rep = [i.replace("A:", "A_") for i in cams] new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name]) - if new_env_name in cls.env_list: - return new_env_name visual_keys = [f"rgb:{c}:224x224:2d" for c in cams] if from_depths: visual_keys.extend([f"d:{c}:224x224:2d" for c in cams])