Skip to content

Commit

Permalink
added wrist optionally
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed May 21, 2024
1 parent 46bba6e commit a791ba1
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions src/rpad/rlbench_utils/placement_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_rgb_point_cloud_by_object_names(rgb, point_cloud, seg_labels, names):
return get_rgb_point_cloud_by_object_handles(rgb, point_cloud, seg_labels, handles)


def obs_to_rgb_point_cloud(obs):
def obs_to_rgb_point_cloud(obs, include_wrist_cam=False):
# Get the overhead, left, front, and right RGB images.
overhead_rgb = obs.overhead_rgb
left_rgb = obs.left_shoulder_rgb
Expand Down Expand Up @@ -84,31 +84,25 @@ def obs_to_rgb_point_cloud(obs):

# Stack the RGB and point cloud images together.
rgb = np.vstack(
(
overhead_rgb,
left_rgb,
right_rgb,
front_rgb,
# wrist_rgb,
)
(overhead_rgb, left_rgb, right_rgb, front_rgb)
if not include_wrist_cam
else (overhead_rgb, left_rgb, right_rgb, front_rgb, wrist_rgb)
)
point_cloud = np.vstack(
(
(overhead_point_cloud, left_point_cloud, right_point_cloud, front_point_cloud)
if not include_wrist_cam
else (
overhead_point_cloud,
left_point_cloud,
right_point_cloud,
front_point_cloud,
# wrist_point_cloud,
wrist_point_cloud,
)
)
mask = np.vstack(
(
overhead_mask,
left_mask,
right_mask,
front_mask,
# wrist_mask,
)
(overhead_mask, left_mask, right_mask, front_mask)
if not include_wrist_cam
else (overhead_mask, left_mask, right_mask, front_mask, wrist_mask)
)

return rgb, point_cloud, mask
Expand Down Expand Up @@ -284,6 +278,7 @@ def __init__(
debugging: bool = False,
anchor_mode: AnchorMode = AnchorMode.SINGLE_OBJECT,
action_mode: ActionMode = ActionMode.OBJECT,
include_wrist_cam: bool = False,
) -> None:
"""Dataset for RL-Bench placement tasks.
Expand All @@ -309,6 +304,7 @@ def __init__(
self.variation = 0
self.debugging = debugging
self.use_first_as_init_keyframe = use_first_as_init_keyframe
self.include_wrist_cam = include_wrist_cam

if self.task_name not in TASK_DICT:
raise ValueError(f"Task name {self.task_name} not supported.")
Expand Down Expand Up @@ -456,7 +452,9 @@ def _select_anchor_vals(rgb, point_cloud, mask):
)

# Merge all the initial point clouds and masks into one.
init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(initial_obs)
init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(
initial_obs, self.include_wrist_cam
)

init_action_rgb, init_action_point_cloud = _select_action_vals(
init_rgb, init_point_cloud, init_mask
Expand All @@ -467,7 +465,9 @@ def _select_anchor_vals(rgb, point_cloud, mask):
)

# Merge all the key point clouds and masks into one.
key_rgb, key_point_cloud, key_mask = obs_to_rgb_point_cloud(key_obs)
key_rgb, key_point_cloud, key_mask = obs_to_rgb_point_cloud(
key_obs, self.include_wrist_cam
)

# Split the key point cloud and rgb into action and anchor.
key_action_rgb, key_action_point_cloud = _select_action_vals(
Expand Down

0 comments on commit a791ba1

Please sign in to comment.