Skip to content

Commit

Permalink
unit tests with from_depths
Browse files Browse the repository at this point in the history
  • Loading branch information
sriramsk1999 committed Apr 20, 2024
1 parent 7d8c693 commit 9c39412
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 4 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3396,10 +3396,11 @@ class TestRoboHive:
# In the CI, robohive should not coexist with other libs so that's fine.
# Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT
@pytest.mark.parametrize("from_pixels", [False, True])
@pytest.mark.parametrize("from_depths", [False, True])
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
def test_robohive(self, envname, from_pixels):
def test_robohive(self, envname, from_pixels, from_depths):
with set_gym_backend("gymnasium"):
torchrl_logger.info(f"{envname}-{from_pixels}")
torchrl_logger.info(f"{envname}-{from_pixels}-{from_depths}")
if any(
substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")
):
Expand All @@ -3415,7 +3416,7 @@ def test_robohive(self, envname, from_pixels):
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels)
env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
Expand Down
2 changes: 0 additions & 2 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ def register_visual_env(cls, env_name, cams, from_depths):
cams = sorted(cams)
cams_rep = [i.replace("A:", "A_") for i in cams]
new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name])
if new_env_name in cls.env_list:
return new_env_name
visual_keys = [f"rgb:{c}:224x224:2d" for c in cams]
if from_depths:
visual_keys.extend([f"d:{c}:224x224:2d" for c in cams])
Expand Down

0 comments on commit 9c39412

Please sign in to comment.