Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Key Array Caching #962

Merged
merged 5 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions omnigibson/sensors/vision_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,6 @@ def _remap_modality(self, modality, obs, info, raw_obs):
obs[modality], info[modality] = self._remap_instance_segmentation(
obs[modality],
id_to_labels,
obs["seg_semantic"],
info["seg_semantic"],
id=(modality == "seg_instance_id"),
)
elif "bbox" in modality:
Expand Down Expand Up @@ -387,16 +385,14 @@ def _remap_semantic_segmentation(self, img, id_to_labels):

return VisionSensor.SEMANTIC_REMAPPER.remap(replicator_mapping, semantic_class_id_to_name(), img, image_keys)

def _remap_instance_segmentation(self, img, id_to_labels, semantic_img, semantic_labels, id=False):
def _remap_instance_segmentation(self, img, id_to_labels, id=False):
"""
Remap the instance segmentation image to our own instance IDs.
Also, correct the id_to_labels input with our new labels and return it.

Args:
img (th.tensor): Instance segmentation image to remap
id_to_labels (dict): Dictionary of instance IDs to class labels
semantic_img (th.tensor): Semantic segmentation image to use for instance registry
semantic_labels (dict): Dictionary of semantic IDs to class labels
id (bool): Whether to remap for instance ID segmentation
Returns:
th.tensor: Remapped instance segmentation image
Expand Down
14 changes: 14 additions & 0 deletions omnigibson/utils/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ class Remapper:
def __init__(self):
self.key_array = th.empty(0, dtype=th.int32, device="cuda") # Initialize the key_array as empty
self.known_ids = set()
self.unlabelled_ids = set()
self.warning_printed = set()

def clear(self):
"""Resets the key_array to empty."""
self.key_array = th.empty(0, dtype=th.int32, device="cuda")
self.known_ids = set()
self.unlabelled_ids = set()

def remap(self, old_mapping, new_mapping, image, image_keys=None):
"""
Expand Down Expand Up @@ -109,6 +111,15 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None):
# Copy the previous key array into the new key array
self.key_array[: len(prev_key_array)] = prev_key_array

# Retrospectively inspect our cached ids against the old mapping and update the key array
updated_ids = set()
for unlabelled_id in self.unlabelled_ids:
if unlabelled_id in old_mapping and old_mapping[unlabelled_id] != "unlabelled":
# If an object was previously unlabelled but now has a label, we need to update the key array
updated_ids.add(unlabelled_id)
self.unlabelled_ids -= updated_ids
self.known_ids -= updated_ids

new_keys = old_mapping.keys() - self.known_ids
if new_keys:
self.known_ids.update(new_keys)
Expand All @@ -118,6 +129,9 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None):
new_key = next((k for k, v in new_mapping.items() if v == label), None)
assert new_key is not None, f"Could not find a new key for label {label} in new_mapping!"
self.key_array[key] = new_key
if label == "unlabelled":
# Some objects in the image might be unlabelled first but later get a valid label later, so we keep track of them
self.unlabelled_ids.add(key)

# For all the values that exist in the image but not in old_mapping.keys(), we map them to whichever key in
# new_mapping that equals to 'unlabelled'. This is needed because some values in the image don't necessarily
Expand Down
Loading