Skip to content

Commit

Permalink
Use RigidContactAPI for TouchingAnyCondition
Browse files Browse the repository at this point in the history
  • Loading branch information
hang-yin committed Oct 21, 2024
1 parent 8fb9fdf commit e562915
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions omnigibson/transition_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from omnigibson.utils.python_utils import Registerable, classproperty, subclass_factory, torch_delete
from omnigibson.utils.registry_utils import Registry
from omnigibson.utils.ui_utils import create_module_logger, disclaimer
from omnigibson.utils.usd_utils import RigidContactAPI

# Create module logger
log = create_module_logger(module_name=__name__)
Expand Down Expand Up @@ -333,6 +334,10 @@ class TouchingAnyCondition(RuleCondition):
"""
Rule condition that prunes object candidates from @filter_1_name, only keeping any that are touching any object
from @filter_2_name
Note that this condition uses the RigidContactAPI for contact checking. This is not a persistent contact check,
meaning that if objects get in contact for some time and both fall asleep, the contact will not be detected.
To get persistent contact checking, please use contact_sensor.
"""

def __init__(self, filter_1_name, filter_2_name):
Expand All @@ -346,21 +351,45 @@ def __init__(self, filter_1_name, filter_2_name):
self._filter_1_name = filter_1_name
self._filter_2_name = filter_2_name

# Maps object to set of rigid bodies corresponding to filter 2
self._filter_2_bodies = None
# Will be filled in during self.initialize
# Maps object to the list of rigid body idxs in the global contact matrix corresponding to filter 1
self._filter_1_idxs = None

# If optimized, filter_2_idxs will be used, otherwise filter_2_bodies will be used!
# Maps object to the list of rigid body idxs in the global contact matrix corresponding to filter 2
self._filter_2_idxs = None

def refresh(self, object_candidates):
# Register body mappings
self._filter_2_bodies = {obj: set(obj.links.values()) for obj in object_candidates[self._filter_2_name]}
# Register idx mappings
self._filter_1_idxs = {
obj: [RigidContactAPI.get_body_row_idx(link.prim_path)[1] for link in obj.links.values()]
for obj in object_candidates[self._filter_1_name]
}
self._filter_2_idxs = {
obj: th.tensor(
[RigidContactAPI.get_body_col_idx(link.prim_path)[1] for link in obj.links.values()],
dtype=th.float32,
)
for obj in object_candidates[self._filter_2_name]
}

def __call__(self, object_candidates):
# Keep any of the @filter_2_name's objects
# Keep any object that has non-zero impulses between itself and any of the @filter_2_name's objects
objs = []

# Manually check contact
filter_2_bodies = set.union(*(self._filter_2_bodies[obj] for obj in object_candidates[self._filter_2_name]))
# Batch check for each object
for obj in object_candidates[self._filter_1_name]:
if len(obj.states[ContactBodies].get_value().intersection(filter_2_bodies)) > 0:
# Get all impulses between @obj and any object in @filter_2_name that are in the same scene
idxs_to_check = th.cat(
[
self._filter_2_idxs[obj2]
for obj2 in object_candidates[self._filter_2_name]
if obj2.scene == obj.scene
]
)
if th.any(
RigidContactAPI.get_all_impulses(obj.scene.idx)[self._filter_1_idxs[obj]][:, idxs_to_check.tolist()]
):
objs.append(obj)

# Update candidates
Expand Down

0 comments on commit e562915

Please sign in to comment.