Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Functionality to reset_state for specific object #268

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,70 @@ def reset_state(self, inference_state):
inference_state["output_dict_per_obj"].clear()
inference_state["temp_output_dict_per_obj"].clear()

@torch.inference_mode()
def reset_state_for_objectId(self, inference_state, obj_id):
    """Remove one object from `inference_state` and re-index the remaining ones.

    The point inputs, mask inputs, per-object outputs and temporary per-object
    outputs stored for `obj_id` are cleared and dropped. The object's row
    (position `obj_idx` along dimension 0) is cut out of the `pred_masks`,
    `obj_ptr` and `maskmem_features` tensors of every frame in
    ``inference_state["output_dict"]``. Every object whose index was greater
    than the removed one is shifted down by one so indices stay contiguous.

    Args:
        inference_state: State dictionary, mutated in place. Must contain
            "obj_id_to_idx", "obj_idx_to_id", "obj_ids", "output_dict" and
            the four "*_per_obj" containers keyed by object index.
        obj_id: ID of the object to remove.

    Raises:
        ValueError: If `obj_id` is not registered in "obj_id_to_idx".
    """
    obj_id_to_idx = inference_state["obj_id_to_idx"]
    if obj_id not in obj_id_to_idx:
        raise ValueError(f"Object ID {obj_id} not found in the current state.")
    obj_idx = obj_id_to_idx[obj_id]

    storage_keys = ("cond_frame_outputs", "non_cond_frame_outputs")

    # Empty the removed object's containers in place so that any outside
    # reference to them also sees the data gone.
    inference_state["point_inputs_per_obj"][obj_idx].clear()
    inference_state["mask_inputs_per_obj"][obj_idx].clear()
    for per_obj_key in ("output_dict_per_obj", "temp_output_dict_per_obj"):
        for storage_key in storage_keys:
            inference_state[per_obj_key][obj_idx][storage_key].clear()

    # Cut the object's row out of the batched per-frame tensors. A key that
    # is absent or holds None is skipped (previously only `maskmem_features`
    # was guarded against None; the other two raised TypeError).
    for storage_key in storage_keys:
        for frame_out in inference_state["output_dict"][storage_key].values():
            for tensor_key in ("pred_masks", "obj_ptr", "maskmem_features"):
                batched = frame_out.get(tensor_key)
                if batched is None:
                    continue
                frame_out[tensor_key] = torch.cat(
                    [batched[:obj_idx], batched[obj_idx + 1:]]
                )

    # Drop the object from the ID bookkeeping and shift higher indices down.
    del obj_id_to_idx[obj_id]
    inference_state["obj_ids"].remove(obj_id)
    for other_id, other_idx in list(obj_id_to_idx.items()):
        if other_idx > obj_idx:
            obj_id_to_idx[other_id] = other_idx - 1

    # Rebuild the inverse mapping from the forward one instead of patching it
    # entry by entry: an in-place shift is only correct when the forward
    # mapping iterates in ascending index order, and otherwise deletes
    # entries it has just written.
    obj_idx_to_id = inference_state["obj_idx_to_id"]
    obj_idx_to_id.clear()
    obj_idx_to_id.update(
        (idx, remaining_id) for remaining_id, idx in obj_id_to_idx.items()
    )

    # Shift the per-object containers down over the removed slot, then drop
    # the now-duplicated last slot.
    for store in (
        inference_state["point_inputs_per_obj"],
        inference_state["mask_inputs_per_obj"],
        inference_state["output_dict_per_obj"],
        inference_state["temp_output_dict_per_obj"],
    ):
        last_idx = len(store) - 1
        for idx in range(obj_idx, last_idx):
            store[idx] = store[idx + 1]
        del store[last_idx]



def _reset_tracking_results(self, inference_state):
"""Reset all tracking inputs and results across the videos."""
for v in inference_state["point_inputs_per_obj"].values():
Expand Down