From 548f961a771896016c77fc0870537e6746c9789d Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Tue, 27 Aug 2024 10:03:36 +0000 Subject: [PATCH] add function to reset_state for a specific object --- sam2/sam2_video_predictor.py | 64 ++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 8b2fd6c4..82ce5c54 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -769,6 +769,70 @@ def reset_state(self, inference_state): inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() + @torch.inference_mode() + def reset_state_for_objectId(self, inference_state, obj_id): + """Remove all input points, masks, and tracking results for a specific object ID.""" + if obj_id not in inference_state["obj_id_to_idx"]: + raise ValueError(f"Object ID {obj_id} not found in the current state.") + + obj_idx = inference_state["obj_id_to_idx"][obj_id] + + # Clear point inputs for the specific object + inference_state["point_inputs_per_obj"][obj_idx].clear() + + # Clear mask inputs for the specific object + inference_state["mask_inputs_per_obj"][obj_idx].clear() + + # Clear output dict for the specific object + inference_state["output_dict_per_obj"][obj_idx]["cond_frame_outputs"].clear() + inference_state["output_dict_per_obj"][obj_idx]["non_cond_frame_outputs"].clear() + + # Clear temporary output dict for the specific object + inference_state["temp_output_dict_per_obj"][obj_idx]["cond_frame_outputs"].clear() + inference_state["temp_output_dict_per_obj"][obj_idx]["non_cond_frame_outputs"].clear() + + # Update the main output_dict to remove the specific object's data + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + for frame_idx, frame_data in inference_state["output_dict"][storage_key].items(): + if "pred_masks" in frame_data: + frame_data["pred_masks"] = torch.cat([ + frame_data["pred_masks"][:obj_idx], + frame_data["pred_masks"][obj_idx+1:] + ]) + if "obj_ptr" in frame_data: + frame_data["obj_ptr"] = torch.cat([ + frame_data["obj_ptr"][:obj_idx], + frame_data["obj_ptr"][obj_idx+1:] + ]) + if "maskmem_features" in frame_data and frame_data["maskmem_features"] is not None: + frame_data["maskmem_features"] = torch.cat([ + frame_data["maskmem_features"][:obj_idx], + frame_data["maskmem_features"][obj_idx+1:] + ]) + + # Remove the object from the mapping dictionaries + del inference_state["obj_id_to_idx"][obj_id] + del inference_state["obj_idx_to_id"][obj_idx] + inference_state["obj_ids"].remove(obj_id) + + # Shift the remaining object indices + for remaining_id, remaining_idx in list(inference_state["obj_id_to_idx"].items()): + if remaining_idx > obj_idx: + inference_state["obj_id_to_idx"][remaining_id] -= 1 + inference_state["obj_idx_to_id"][remaining_idx - 1] = remaining_id + del inference_state["obj_idx_to_id"][remaining_idx] + + # Update other data structures that depend on object indices + for data_dict in [inference_state["point_inputs_per_obj"], + inference_state["mask_inputs_per_obj"], + inference_state["output_dict_per_obj"], + inference_state["temp_output_dict_per_obj"]]: + for idx in range(obj_idx, len(data_dict) - 1): + data_dict[idx] = data_dict[idx + 1] + del data_dict[len(data_dict) - 1] + + + def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos.""" for v in inference_state["point_inputs_per_obj"].values():