From e562915418d89e677d4f988d1149e42b56416acc Mon Sep 17 00:00:00 2001 From: hang-yin Date: Mon, 21 Oct 2024 11:36:36 -0700 Subject: [PATCH] Use RigidContactAPI for TouchingAnyCondition --- omnigibson/transition_rules.py | 45 ++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/omnigibson/transition_rules.py b/omnigibson/transition_rules.py index f99c89fd1..ecdca6e87 100644 --- a/omnigibson/transition_rules.py +++ b/omnigibson/transition_rules.py @@ -25,6 +25,7 @@ from omnigibson.utils.python_utils import Registerable, classproperty, subclass_factory, torch_delete from omnigibson.utils.registry_utils import Registry from omnigibson.utils.ui_utils import create_module_logger, disclaimer +from omnigibson.utils.usd_utils import RigidContactAPI # Create module logger log = create_module_logger(module_name=__name__) @@ -333,6 +334,10 @@ class TouchingAnyCondition(RuleCondition): """ Rule condition that prunes object candidates from @filter_1_name, only keeping any that are touching any object from @filter_2_name + + Note that this condition uses the RigidContactAPI for contact checking. This is not a persistent contact check, + meaning that if objects get in contact for some time and both fall asleep, the contact will not be detected. + To get persistent contact checking, please use contact_sensor. """ def __init__(self, filter_1_name, filter_2_name): @@ -346,21 +351,45 @@ def __init__(self, filter_1_name, filter_2_name): self._filter_1_name = filter_1_name self._filter_2_name = filter_2_name - # Maps object to set of rigid bodies corresponding to filter 2 - self._filter_2_bodies = None + # Will be filled in during self.initialize + # Maps object to the list of rigid body idxs in the global contact matrix corresponding to filter 1 + self._filter_1_idxs = None + + # If optimized, filter_2_idxs will be used, otherwise filter_2_bodies will be used! + # Maps object to the list of rigid body idxs in the global contact matrix corresponding to filter 2 + self._filter_2_idxs = None def refresh(self, object_candidates): - # Register body mappings - self._filter_2_bodies = {obj: set(obj.links.values()) for obj in object_candidates[self._filter_2_name]} + # Register idx mappings + self._filter_1_idxs = { + obj: [RigidContactAPI.get_body_row_idx(link.prim_path)[1] for link in obj.links.values()] + for obj in object_candidates[self._filter_1_name] + } + self._filter_2_idxs = { + obj: th.tensor( + [RigidContactAPI.get_body_col_idx(link.prim_path)[1] for link in obj.links.values()], + dtype=th.float32, + ) + for obj in object_candidates[self._filter_2_name] + } def __call__(self, object_candidates): - # Keep any of the @filter_2_name's objects + # Keep any object that has non-zero impulses between itself and any of the @filter_2_name's objects objs = [] - # Manually check contact - filter_2_bodies = set.union(*(self._filter_2_bodies[obj] for obj in object_candidates[self._filter_2_name])) + # Batch check for each object for obj in object_candidates[self._filter_1_name]: - if len(obj.states[ContactBodies].get_value().intersection(filter_2_bodies)) > 0: + # Get all impulses between @obj and any object in @filter_2_name that are in the same scene + idxs_to_check = th.cat( + [ + self._filter_2_idxs[obj2] + for obj2 in object_candidates[self._filter_2_name] + if obj2.scene == obj.scene + ] + ) + if th.any( + RigidContactAPI.get_all_impulses(obj.scene.idx)[self._filter_1_idxs[obj]][:, idxs_to_check.tolist()] + ): objs.append(obj) # Update candidates