Skip to content

Commit

Permalink
ignore collisions
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed May 13, 2024
1 parent d2ad10c commit 657288a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
72 changes: 69 additions & 3 deletions notebooks/explore_dset2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"source": [
"from rpad.rlbench_utils.placement_dataset import RLBenchPlacementDataset, load_handle_mapping, load_state_pos_dict, TASK_DICT\n",
"import numpy as np\n",
"\n",
"from rpad.rlbench_utils.task_info import RLBENCH_10_TASKS\n",
"from rpad.visualize_3d.plots import segmentation_fig\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
Expand All @@ -32,7 +32,7 @@
"outputs": [],
"source": [
"dset = RLBenchPlacementDataset(\n",
" dataset_root=\"/data/rlbench10/\",\n",
" dataset_root=\"/data/rlbench10_collisions/\",\n",
" # task_name=\"stack_wine\",\n",
" # task_name=\"insert_onto_square_peg\",\n",
" # task_name=\"insert_usb_in_computer\",\n",
Expand All @@ -42,11 +42,38 @@
" # task_name=\"solve_puzzle\",\n",
" # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
" task_name=\"slide_block_to_target\",\n",
" demos=range(10),\n",
" demos=range(100),\n",
" phase=\"all\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = dset[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data[\"ignore_collisions\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -707,6 +734,45 @@
"unique_elements"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for task_name in RLBENCH_10_TASKS:\n",
" print(\"--------------------\")\n",
" print(f\"Task: {task_name}\")\n",
" print(\"--------------------\")\n",
" for phase in TASK_DICT[task_name][\"phase\"].keys():\n",
"\n",
" dset = RLBenchPlacementDataset(\n",
" dataset_root=\"/data/rlbench10_collisions/\",\n",
" # task_name=\"stack_wine\",\n",
" # task_name=\"insert_onto_square_peg\",\n",
" # task_name=\"insert_usb_in_computer\",\n",
" # task_name=\"phone_on_base\",\n",
" # task_name=\"put_toilet_roll_on_stand\",\n",
" # task_name=\"place_hanger_on_rack\",\n",
" # task_name=\"solve_puzzle\",\n",
" # task_name=\"take_umbrella_out_of_umbrella_stand\",\n",
" task_name=task_name,\n",
" demos=range(100),\n",
" phase=phase,\n",
" ) \n",
" ignore_collisions_all = []\n",
" for i in range(len(dset)):\n",
" try:\n",
" data = dset[i]\n",
" ignore_collisions_all.append(data[\"ignore_collisions\"])\n",
" except:\n",
" print(f\"Error in task {task_name}, phase {phase}, demo {i}\")\n",
" ignore_all = (np.array(ignore_collisions_all).any()) \n",
" print(f\"Phase: {phase}; Ignore Collisions: {ignore_all}\")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
7 changes: 7 additions & 0 deletions src/rpad/rlbench_utils/placement_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,12 @@ def extract_pose(obs, key):
T_init_key = T_action_key_world @ np.linalg.inv(T_action_init_world)
T_anchor_key_world = extract_pose(key_obs, "anchor_pose_name")

if hasattr(initial_obs, "ignore_collisions"):
ignore_collisions = initial_obs.ignore_collisions
ignore_collisions = torch.from_numpy(ignore_collisions.astype(np.int32))
else:
ignore_collisions = None

return {
"init_action_rgb": torch.from_numpy(init_action_rgb),
"init_action_pc": torch.from_numpy(init_action_point_cloud),
Expand All @@ -489,4 +495,5 @@ def extract_pose(obs, key):
"key_front_mask": torch.from_numpy(key_obs.front_mask.astype(np.int32)),
"phase": phase,
"phase_onehot": torch.from_numpy(phase_onehot),
"ignore_collisions": ignore_collisions,
}

0 comments on commit 657288a

Please sign in to comment.