Skip to content

Commit

Permalink
Fix key array caching bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hang-yin committed Oct 10, 2024
1 parent 2f60281 commit af526a3
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions omnigibson/utils/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ class Remapper:
def __init__(self):
self.key_array = th.empty(0, dtype=th.int32, device="cuda") # Initialize the key_array as empty
self.known_ids = set()
self.unlabelled_ids = set()
self.warning_printed = set()

def clear(self):
"""Resets the key_array to empty."""
self.key_array = th.empty(0, dtype=th.int32, device="cuda")
self.known_ids = set()
self.unlabelled_ids = set()

def remap(self, old_mapping, new_mapping, image, image_keys=None):
"""
Expand Down Expand Up @@ -109,6 +111,15 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None):
# Copy the previous key array into the new key array
self.key_array[: len(prev_key_array)] = prev_key_array

# Retrospectively inspect our cached ids against the old mapping and update the key array
updated_ids = set()
for unlabelled_id in self.unlabelled_ids:
if unlabelled_id in old_mapping and old_mapping[unlabelled_id] != "unlabelled":
# If an object was previously unlabelled but now has a label, we need to update the key array
updated_ids.add(unlabelled_id)
self.unlabelled_ids -= updated_ids
self.known_ids -= updated_ids

new_keys = old_mapping.keys() - self.known_ids
if new_keys:
self.known_ids.update(new_keys)
Expand All @@ -118,6 +129,9 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None):
new_key = next((k for k, v in new_mapping.items() if v == label), None)
assert new_key is not None, f"Could not find a new key for label {label} in new_mapping!"
self.key_array[key] = new_key
if label == "unlabelled":
# Some objects in the image might be unlabelled first but later get a valid label later, so we keep track of them
self.unlabelled_ids.add(key)

# For all the values that exist in the image but not in old_mapping.keys(), we map them to whichever key in
# new_mapping that equals to 'unlabelled'. This is needed because some values in the image don't necessarily
Expand Down

0 comments on commit af526a3

Please sign in to comment.