Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Return depth from RoboHiveEnv #2058

Merged
merged 6 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3396,10 +3396,11 @@ class TestRoboHive:
# In the CI, robohive should not coexist with other libs so that's fine.
# Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT
@pytest.mark.parametrize("from_pixels", [False, True])
@pytest.mark.parametrize("from_depths", [False, True])
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
def test_robohive(self, envname, from_pixels):
def test_robohive(self, envname, from_pixels, from_depths):
with set_gym_backend("gymnasium"):
torchrl_logger.info(f"{envname}-{from_pixels}")
torchrl_logger.info(f"{envname}-{from_pixels}-{from_depths}")
if any(
substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")
):
Expand All @@ -3415,7 +3416,7 @@ def test_robohive(self, envname, from_pixels):
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels)
env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
Expand Down
36 changes: 26 additions & 10 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
be returned (by default under the ``"pixels"`` entry in the output tensordict).
If ``False``, observations (eg, states) and pixels will be returned
whenever ``from_pixels=True``. Defaults to ``True``.
from_depths (bool, optional): if ``True``, an attempt to return the depth
observations from the env will be performed. By default, these observations
will be written under the ``"depths"`` entry. Requires ``from_pixels`` to be ``True``.
Defaults to ``False``.
frame_skip (int, optional): if provided, indicates for how many steps the
same action is to be repeated. The observation returned will be the
last observation of the sequence, whereas the reward will be the sum
Expand Down Expand Up @@ -155,6 +159,7 @@ def _build_env( # noqa: F811
env_name: str,
from_pixels: bool = False,
pixels_only: bool = False,
from_depths: bool = False,
**kwargs,
) -> "gym.core.Env": # noqa: F821
if from_pixels:
Expand All @@ -168,7 +173,9 @@ def _build_env( # noqa: F811
)
kwargs["cameras"] = self.get_available_cams(env_name)
cams = list(kwargs.pop("cameras"))
env_name = self.register_visual_env(cams=cams, env_name=env_name)
env_name = self.register_visual_env(
cams=cams, env_name=env_name, from_depths=from_depths
)

elif "cameras" in kwargs and kwargs["cameras"]:
raise RuntimeError("Got a list of cameras but from_pixels is set to False.")
Expand All @@ -194,10 +201,6 @@ def _build_env( # noqa: F811
**kwargs,
)
self.wrapper_frame_skip = 1
if env.visual_keys:
from_pixels = bool(len(env.visual_keys))
else:
from_pixels = False
except TypeError as err:
if "unexpected keyword argument 'frameskip" not in str(err):
raise err
Expand All @@ -209,6 +212,7 @@ def _build_env( # noqa: F811
# except Exception as err:
# raise RuntimeError(f"Failed to build env {env_name}.") from err
self.from_pixels = from_pixels
self.from_depths = from_depths
self.render_device = render_device
if kwargs.get("read_info", True):
self.set_info_dict_reader(self.read_info)
Expand All @@ -224,7 +228,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
return out

@classmethod
def register_visual_env(cls, env_name, cams):
def register_visual_env(cls, env_name, cams, from_depths):
with set_directory(cls.CURR_DIR):
from robohive.envs.env_variants import register_env_variant

Expand All @@ -233,9 +237,9 @@ def register_visual_env(cls, env_name, cams):
cams = sorted(cams)
cams_rep = [i.replace("A:", "A_") for i in cams]
new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name])
if new_env_name in cls.env_list:
return new_env_name
visual_keys = [f"rgb:{c}:224x224:2d" for c in cams]
if from_depths:
visual_keys.extend([f"d:{c}:224x224:2d" for c in cams])
register_env_variant(
env_name,
variants={
Expand All @@ -262,20 +266,26 @@ def get_obs():
if self.from_pixels:
visual = self.env.get_exteroception()
obs_dict.update(visual)
pixel_list = []
pixel_list, depth_list = [], []
for obs_key in obs_dict:
if obs_key.startswith("rgb"):
pix = obs_dict[obs_key]
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif obs_key.startswith("d:"):
dep = obs_dict[obs_key]
dep = dep[None]
depth_list.append(dep)
elif obs_key in env.obs_keys:
value = env.obs_dict[obs_key]
if not value.shape:
value = value[None]
_dict[obs_key] = value
if pixel_list:
_dict["pixels"] = np.concatenate(pixel_list, 0)
if depth_list:
_dict["depths"] = np.concatenate(depth_list, 0)
return _dict

for i in range(3):
Expand Down Expand Up @@ -335,7 +345,7 @@ def read_obs(self, observation):
pass
# recover vec
obsdict = {}
pixel_list = []
pixel_list, depth_list = [], []
if self.from_pixels:
visual = self.env.get_exteroception()
observations.update(visual)
Expand All @@ -345,6 +355,10 @@ def read_obs(self, observation):
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif key.startswith("d:"):
dep = observations[key]
dep = dep[None]
depth_list.append(dep)
elif key in self._env.obs_keys:
value = observations[key]
if not value.shape:
Expand All @@ -354,6 +368,8 @@ def read_obs(self, observation):
# obsvec = np.concatenate(obsvec, 0)
if self.from_pixels:
obsdict.update({"pixels": np.concatenate(pixel_list, 0)})
if self.from_pixels and self.from_depths:
obsdict.update({"depths": np.concatenate(depth_list, 0)})
out = obsdict
return super().read_obs(out)

Expand Down
Loading