diff --git a/tripy/examples/segment-anything-model-v2/.gitignore b/tripy/examples/segment-anything-model-v2/.gitignore new file mode 100644 index 000000000..90594bd02 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/.gitignore @@ -0,0 +1,3 @@ +checkpoints/ +saved_engines/ +output/ diff --git a/tripy/examples/segment-anything-model-v2/LICENSE_sam2 b/tripy/examples/segment-anything-model-v2/LICENSE_sam2 new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/LICENSE_sam2 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/tripy/examples/segment-anything-model-v2/README.md b/tripy/examples/segment-anything-model-v2/README.md new file mode 100644 index 000000000..b8001f160 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/README.md @@ -0,0 +1,42 @@ +# SAM2: Segment Anything in Images and Videos + +## Introduction + +This is an implementation of SAM2 model ([original repository](https://github.com/facebookresearch/sam2/tree/main) by Meta). + +## Running The Example + +### Image pipeline + +1. Install prerequisites: + + ```bash + sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y + wget -O truck.jpg https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg + mkdir checkpoints && cd checkpoints && wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt + python3 -m pip install -r requirements.txt + ``` + +2. Run the example: + + ```bash + python3 image_demo.py + ``` + + + + + +### Video segmentation pipeline + +TBD + + +## License +The SAM2 model checkpoints and associated model code are sourced from Meta's [SAM2 repository](https://github.com/facebookresearch/sam2/tree/main) and are licensed under the Apache 2.0 license (included as LICENSE_sam2 in our sample). diff --git a/tripy/examples/segment-anything-model-v2/configs/__init__.py b/tripy/examples/segment-anything-model-v2/configs/__init__.py new file mode 100644 index 000000000..4e3ee0298 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/configs/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml b/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml new file mode 100644 index 000000000..6c84c6089 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml @@ -0,0 +1,126 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + dtype: float16 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + dtype: float16 + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + dtype: float16 + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + dtype: float16 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + dtype: float16 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + dtype: float16 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + dtype: float16 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + dtype: float16 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + dtype: float16 + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + dtype: float16 + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + 
only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + model_precision: float16 diff --git a/tripy/examples/segment-anything-model-v2/image_demo.py b/tripy/examples/segment-anything-model-v2/image_demo.py new file mode 100644 index 000000000..43107c833 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/image_demo.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import cv2 +import os +import time +import numpy as np +import torch +import tripy as tp +import matplotlib.pyplot as plt + +plt.switch_backend("agg") # Switch to non-interactive backend +from PIL import Image +from typing import Tuple, Optional, Dict + +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +parser = argparse.ArgumentParser() +parser.add_argument("-b", "--batch", type=int, default=2, help="batch size of the input images, between [1, 4]") + + +def process_and_show_mask( + mask: np.ndarray, ax: plt.Axes, random_color: bool = False, borders: bool = True +) -> np.ndarray: + """ + Process and display a segmentation mask, returning the processed mask for testing. + """ + # Generate mask color + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) + + # Process mask + h, w = mask.shape[-2:] + mask = mask.astype(np.uint8) + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + + if borders: + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] + mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) + + ax.imshow(mask_image) + return mask_image + + +def show_points( + coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375 +) -> Tuple[np.ndarray, np.ndarray]: + """ + Display point prompts and return point coordinates for testing. + """ + pos_points = coords[labels == 1] + neg_points = coords[labels == 0] + + ax.scatter( + pos_points[:, 0], + pos_points[:, 1], + color="green", + marker="*", + s=marker_size, + edgecolor="white", + linewidth=1.25, + ) + ax.scatter( + neg_points[:, 0], + neg_points[:, 1], + color="red", + marker="*", + s=marker_size, + edgecolor="white", + linewidth=1.25, + ) + + return pos_points, neg_points + + +def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]: + """ + Display a bounding box and return its coordinates for testing. 
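+    The box is expected in XYXY format (x0, y0, x1, y1).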
+ """ + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)) + return x0, y0, w, h + + +def process_predictions( + image: np.ndarray, + masks: np.ndarray, + scores: np.ndarray, + logits: np.ndarray, + point_coords: Optional[np.ndarray] = None, + box_coords: Optional[np.ndarray] = None, + input_labels: Optional[np.ndarray] = None, + save_path: Optional[str] = None, +) -> Dict[str, np.ndarray]: + """ + Process and visualize predictions, returning a dictionary containing processed masks, scores, and logits. + """ + processed_masks = [] + + # Create output directory if it doesn't exist + if save_path: + os.makedirs(save_path, exist_ok=True) + + for i, (mask, score) in enumerate(zip(masks, scores)): + + fig, ax = plt.subplots(figsize=(10, 10)) + ax.imshow(image) + + processed_mask = process_and_show_mask(mask, ax) + processed_masks.append(processed_mask) + + if point_coords is not None: + assert input_labels is not None, "Input labels required for point prompts" + show_points(point_coords, input_labels, ax) + + if box_coords is not None: + show_box(box_coords, ax) + + if len(scores) > 1: + ax.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) + + ax.axis("off") + + if save_path: + plt.savefig(os.path.join(save_path, f"mask_{i}_score_{score:.3f}.png"), bbox_inches="tight", pad_inches=0) + plt.close(fig) + + print(f"Scores for each prediction: {' '.join(map(str, scores))}") + + return { + "masks": np.array(processed_masks), + "scores": scores, + "logits": logits, + } + + +def main(image_path: str, save_path: Optional[str] = None): + """ + Main execution function. + + Args: + image_path (str): Path to input image + save_path (str, optional): Directory to save visualizations + + Returns: + Dict[str, np.ndarray]: Processing results + """ + + args = parser.parse_args() + + # Load image + image = np.array(Image.open(image_path).convert("RGB")) + image_list = [image] * args.batch + + # Initialize SAM2 model + sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt" + model_cfg = "sam2_hiera_l.yaml" + device = torch.device("cuda") + sam2_model = build_sam2( + model_cfg, + sam2_checkpoint, + device=device, + ) + + # Create predictor and process image + predictor = SAM2ImagePredictor(sam2_model) + + # Set input prompt + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + def time_function(func, num_warmup=5, num_runs=100, description=""): + # Warmup runs + for _ in range(num_warmup): + func() + tp.default_stream().synchronize() + torch.cuda.synchronize() + + # Actual timing + start = time.perf_counter() + for _ in range(num_runs): + output = func() + tp.default_stream().synchronize() + torch.cuda.synchronize() + + end = time.perf_counter() + + avg_time_ms = (end - start) * 1000 / num_runs + print( + f"{description} took {avg_time_ms:.2f} ms per run (averaged over {num_runs} runs, with {num_warmup} warmup runs)" + ) + + return output + + def generate_embedding(): + predictor.set_image_batch(image_list) + return None + + def predict_masks(): + return predictor.predict_batch( + point_coords_batch=[input_point] * args.batch, + point_labels_batch=[input_label] * args.batch, + multimask_output=True, + ) + + predictor.reset_predictor() + time_function(generate_embedding, description="Generating image embedding") + masks, scores, logits = time_function(predict_masks, description="Predicting masks") + + masks = masks[0] + scores = scores[0] + logits = logits[0] + + # Sort masks by 
confidence score + sorted_ind = np.argsort(scores)[::-1] + masks = masks[sorted_ind] + scores = scores[sorted_ind] + logits = logits[sorted_ind] + + # Process and display results + results = process_predictions( + image, + masks, + scores, + logits, + point_coords=input_point, + input_labels=input_label, + save_path=save_path, + ) + return results + + +if __name__ == "__main__": + main("truck.jpg", save_path="output") diff --git a/tripy/examples/segment-anything-model-v2/requirements.txt b/tripy/examples/segment-anything-model-v2/requirements.txt new file mode 100644 index 000000000..5cd719001 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/requirements.txt @@ -0,0 +1,10 @@ +torch>=2.3.1 +torchvision>=0.18.1 +numpy>=1.24.4 +tqdm>=4.66.1 +hydra-core>=1.3.2 +iopath>=0.1.10 +pillow>=9.4.0 +matplotlib>=3.9.1 +jupyter>=1.0.0 +opencv-python>=4.7.0 diff --git a/tripy/examples/segment-anything-model-v2/sam2/__init__.py b/tripy/examples/segment-anything-model-v2/sam2/__init__.py new file mode 100755 index 000000000..7f6c01317 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra import initialize_config_module + +initialize_config_module("configs", version_base="1.2") diff --git a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py new file mode 100644 index 000000000..8bae9eb96 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py @@ -0,0 +1,466 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + +import tripy as tp +import time +import os + + +def set_model_attr(model, attr_path, value): + """ + Set a model attribute, handling both nested and regular attributes. + """ + if "." 
not in attr_path: + setattr(model, attr_path, value) + else: + attrs = attr_path.split(".") + obj = model + # Navigate to the parent object + for attr in attrs[:-1]: + obj = getattr(obj, attr) + # Set the attribute on the parent object + setattr(obj, attrs[-1], value) + + +def get_component_configs(model, cfg): + """ + Get configurations for different components, including both compilation and weight loading info. + """ + batchsize = (1, 2, 4) + num_obj = (1, 2, 4) + model_precision = getattr(cfg["model"], "model_precision", "float32") + return { + "memory_attention": { + "enabled": True, + "model": model.memory_attention, + "dtype": model_precision, + "compile_args": [ + tp.InputInfo( + (4096, 1, 256), + getattr(tp, model_precision), + ), + tp.InputInfo( + ((4100, 16400, 28736), 1, 64), + getattr(tp, model_precision), + ), + tp.InputInfo( + (4096, 1, 256), + getattr(tp, model_precision), + ), + tp.InputInfo( + ((4100, 16400, 28736), 1, 64), + getattr(tp, model_precision), + ), + tp.InputInfo(((4, 16, 64),), tp.int32), + ], + "skip_dtype_convert": ["ln", "norm"], + }, + "sam_mask_decoder_false": { + "enabled": True, + "model": model.sam_mask_decoder, + "dtype": model_precision, + "compile_args": [ + tp.InputInfo( + (1, 256, 64, 64), + dtype=getattr(tp, model_precision), + ), # image_embeddings + tp.InputInfo( + (1, 256, 64, 64), + dtype=getattr(tp, model_precision), + ), # image_pe + tp.InputInfo( + (1, 3, 256), + dtype=getattr(tp, model_precision), + ), # sparse_prompt_embeddings + tp.InputInfo( + (1, 256, 64, 64), + dtype=getattr(tp, model_precision), + ), # dense_prompt_embeddings + False, # multimask_output + False, # repeat_image + tp.InputInfo( + (1, 32, 256, 256), + dtype=getattr(tp, model_precision), + ), # high_res_features_1 + tp.InputInfo( + (1, 64, 128, 128), + dtype=getattr(tp, model_precision), + ), # high_res_features_2 + ], + "skip_dtype_convert": ["ln", "norm", "output_upscaling.1"], + }, + "sam_mask_decoder_true": { + "enabled": True, + "model": model.sam_mask_decoder, + "dtype": model_precision, + "compile_args": [ + tp.InputInfo( + (batchsize, 256, 64, 64), + dtype=getattr(tp, model_precision), + ), # image_embeddings + tp.InputInfo( + (1, 256, 64, 64), + dtype=getattr(tp, model_precision), + ), # image_pe + tp.InputInfo( + (batchsize, 2, 256), + dtype=getattr(tp, model_precision), + ), # sparse_prompt_embeddings + tp.InputInfo( + (batchsize, 256, 64, 64), + dtype=getattr(tp, model_precision), + ), # dense_prompt_embeddings + True, # multimask_output + False, # repeat_image + tp.InputInfo( + (batchsize, 32, 256, 256), + dtype=getattr(tp, model_precision), + ), # high_res_features_1 + tp.InputInfo( + (batchsize, 64, 128, 128), + dtype=getattr(tp, model_precision), + ), # high_res_features_2 + ], + "skip_dtype_convert": ["ln", "norm", "output_upscaling.1"], + "skip_load_state_dict": True, + }, + "sam_mask_decoder.conv_s0": { + "enabled": True, + "model": model.sam_mask_decoder.conv_s0, + "dtype": model_precision, + "compile_args": [ + tp.InputInfo( + (batchsize, 256, 256, 256), + dtype=getattr(tp, model_precision), + ) + ], + "skip_dtype_convert": [], + "skip_load_state_dict": True, + }, + "sam_mask_decoder.conv_s1": { + "enabled": True, + "model": model.sam_mask_decoder.conv_s1, + "dtype": model_precision, + "compile_args": [ + tp.InputInfo( + (batchsize, 256, 128, 128), + dtype=getattr(tp, model_precision), + ) + ], + "skip_dtype_convert": [], + "skip_load_state_dict": True, + }, + "memory_encoder": { + "enabled": True, + "model": model.memory_encoder, + "dtype": 
model_precision, # TODO add fp16 to yaml + "compile_args": [ + tp.InputInfo((1, 256, 64, 64), getattr(tp, model_precision)), + tp.InputInfo((1, 1, 1024, 1024), getattr(tp, model_precision)), + True, + ], + "skip_dtype_convert": ["ln", "norm"] + + [f"encoder.{i}.{param}" for i in range(1, 40, 3) for param in ("weight", "bias")], + }, + "sam_prompt_encoder": { + "enabled": True, + "model": model.sam_prompt_encoder, + "dtype": "float32", + "compile_args": [ + tp.InputInfo((batchsize, num_obj, 2), dtype=tp.float32), + tp.InputInfo((batchsize, num_obj), dtype=tp.int32), + None, + None, + ], + "skip_dtype_convert": [], + "special_handling": lambda original_model: { + setattr( + model.sam_prompt_encoder, + "mask_input_size", + original_model.mask_input_size, + ) + }, + }, + "sam_prompt_encoder.get_dense_pe": { + "enabled": True, + "model": model.sam_prompt_encoder.get_dense_pe, + "dtype": model_precision, + "compile_args": [], + "skip_dtype_convert": [], + "skip_load_state_dict": True, + }, + "image_encoder.compiled_executable": { + "enabled": True, + "model": model.image_encoder.forward, + "dtype": model_precision, + "compile_args": [ + tp.InputInfo( + (batchsize, 3, 1024, 1024), + dtype=getattr( + tp, + model_precision, + ), + ), + ], + "skip_dtype_convert": ["norm"], + "special_key_loading": lambda key: ( + # If it's a neck.convs key that contains 'conv.' + # neck.convs.0.conv.weight -> neck.convs.0.weight + ".".join(parts[:-2] + [parts[-1]]) + if (parts := key.split(".")) and key.startswith("neck.convs") and "conv." in key + else key + ), + "special_handling": lambda original_model: { + setattr( + model.image_encoder, + "trunk", + type("Trunk", (), {"dtype": original_model.trunk.dtype})(), + ) + }, + }, + } + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path, cfg) + + current_dir = os.getcwd() + saved_engines_path = os.path.join(current_dir, "saved_engines") + + # Create the saved_engines directory if it doesn't exist + if not os.path.exists(saved_engines_path): + os.makedirs(saved_engines_path) + + # Get component configurations + components = get_component_configs(model, cfg) + required_components_for_image = [ + "sam_mask_decoder_true", + "sam_mask_decoder.conv_s0", + "sam_mask_decoder.conv_s1", + "sam_prompt_encoder", + "sam_prompt_encoder.get_dense_pe", + "image_encoder.compiled_executable", + ] + + for comp_name, comp_info in components.items(): + if not comp_info["enabled"] or comp_name not in required_components_for_image: + continue + + executable_file = os.path.join(saved_engines_path, comp_name) + if os.path.exists(executable_file): + print(f"Loading existing compiled {comp_name} from {executable_file}") + compiled_model = tp.Executable.load(executable_file) + else: + print(f"Compiling {comp_name}...") + start = time.time() + compiled_model 
= tp.compile(comp_info["model"], args=comp_info["compile_args"]) + print(f"Compilation took {time.time() - start:.2f}s") + compiled_model.save(executable_file) + + old_model = comp_info["model"] + # If model is model.forward, retrieve the original model object + if hasattr(old_model, "__self__"): + old_model = old_model.__self__ + + set_model_attr(model, comp_name, compiled_model) + if "special_handling" in comp_info and comp_info["special_handling"] is not None: + comp_info["special_handling"](old_model) + + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path, cfg) + + current_dir = os.getcwd() + saved_engines_path = os.path.join(current_dir, "saved_engines") + # Create the saved_engines directory if it doesn't exist + if not os.path.exists(saved_engines_path): + os.makedirs(saved_engines_path) + + # Get component configurations + components = get_component_configs(model, cfg) + + for comp_name, comp_info in components.items(): + if not comp_info["enabled"]: + continue + + executable_file = os.path.join(saved_engines_path, comp_name) + if os.path.exists(executable_file): + print(f"Loading existing compiled {comp_name} from {executable_file}") + compiled_model = tp.Executable.load(executable_file) + else: + print(f"Compiling {comp_name}...") + start = time.time() + compiled_model = tp.compile(comp_info["model"], args=comp_info["compile_args"]) + print(f"Compilation took {time.time() - start:.2f}s") + compiled_model.save(executable_file) + + old_model = comp_info["model"] + # If model is model.forward, retrieve the original model object + if hasattr(old_model, "__self__"): + old_model = old_model.__self__ + + set_model_attr(model, comp_name, compiled_model) + if "special_handling" in comp_info and comp_info["special_handling"] is not None: + comp_info["special_handling"](old_model) + + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def load_component_weights(comp_name, component_info, state_dict, checkpoint_dict): + """ + Load weights for a single component from checkpoint into state dict. 
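+    Returns the number of converted keys.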
+ """ + + converted_keys = 0 + if component_info.get("skip_load_state_dict"): + return converted_keys + + for key in checkpoint_dict: + # Remove _true/_false suffixes if present + for suffix in ["_true", "_false", ".compiled_executable"]: + if comp_name.endswith(suffix): + comp_name = comp_name[: -len(suffix)] + break + + if not key.startswith(comp_name): + continue + + new_key = key.replace(f"{comp_name}.", "") + if "special_key_loading" in component_info: + new_key = component_info["special_key_loading"](new_key) + weight = checkpoint_dict[key] + + should_convert = not any(skip in key for skip in component_info["skip_dtype_convert"]) + if should_convert and component_info["dtype"] is not None: + weight = weight.to(getattr(torch, component_info["dtype"])) + + state_dict[new_key] = tp.Parameter(weight.contiguous()) + converted_keys += 1 + + return converted_keys + + +def _load_checkpoint(model, ckpt_path, cfg=None): + + if ckpt_path is None: + return + + sd = torch.load(ckpt_path, map_location="cpu")["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False) + + # Get paths for compiled models + current_dir = os.getcwd() + saved_engines_path = os.path.join(current_dir, "saved_engines") + + # Get component configurations + components = get_component_configs(model, cfg) + + # Process each component + for comp_name, comp_info in components.items(): + + if not comp_info["enabled"] or comp_info.get("skip_load_state_dict"): + continue + + # Skip if compiled model exists + model_path = os.path.join(saved_engines_path, comp_name) + if os.path.exists(model_path): + print(f"Using existing compiled model for {comp_name}") + continue + + # If no compiled model exists, convert and load weights + print(f"Converting weights for {comp_name}") + + comp_model = comp_info["model"] + # If model is model.forward, retrieve the original model object + if hasattr(comp_model, "__self__"): + comp_model = comp_model.__self__ + component_sd = comp_model.state_dict() + converted_keys = load_component_weights(comp_name, comp_info, component_sd, sd) + comp_model.load_state_dict(component_sd, strict=False) + if comp_name == "image_encoder.compiled_executable": + comp_model.trunk.generate_static_pos_embed((256, 256)) + + print(f"Converted {converted_keys} keys for {comp_name}") diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/__init__.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/__init__.py new file mode 100755 index 000000000..a08b2c204 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/__init__.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/__init__.py new file mode 100755 index 000000000..4e3ee0298 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/hieradet.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/hieradet.py new file mode 100755 index 000000000..b2233c5d6 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
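+
+# Tripy implementation of the SAM2 Hiera image-encoder trunk: windowed
+# multi-scale attention (MultiScaleAttention, MultiScaleBlock) and the Hiera
+# backbone that produces multi-scale features for the FPN neck.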
+ +from functools import partial +from typing import Callable, List, Tuple, Union + +import tripy as tp +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) +from sam2.modeling.sam2_utils import MLP, scaled_dot_product_attention + + +def do_pool(x, pool, norm=None): + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = tp.permute(x, (0, 3, 1, 2)) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = tp.permute(x, (0, 2, 3, 1)) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(tp.Module): + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: tp.Module = None, + dtype: tp.dtype = tp.float32, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = tp.Linear(dim, dim_out * 3, dtype=dtype) + self.proj = tp.Linear(dim_out, dim_out, dtype=dtype) + + def __call__(self, x): + B, H, W = x.shape[0:3] + # qkv with shape (B, H * W, 3, nHead, C) + qkv = tp.reshape(self.qkv(x), (B, H * W, 3, self.num_heads, -1)) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(tp.reshape(q, (B, H, W, -1)), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = tp.reshape(q, (B, H * W, self.num_heads, -1)) + + x = scaled_dot_product_attention( + tp.transpose(q, 1, 2), + tp.transpose(k, 1, 2), + tp.transpose(v, 1, 2), + ) + # Transpose back + x = tp.transpose(x, 1, 2) + x = tp.reshape(x, (B, H, W, -1)) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(tp.Module): + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + norm_layer: Union[tp.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: Callable = tp.gelu, + window_size: int = 0, + dtype: tp.dtype = tp.float32, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(tp, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = partial(tp.maxpool, kernel_dims=q_stride, stride=q_stride) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + dtype=dtype, + ) + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + dtype=dtype, + ) + + if dim != dim_out: + self.proj = tp.Linear(dim, dim_out, dtype=dtype) + + def __call__(self, x): + + def call_norm(x, norm): + x_dtype = x.dtype + x = tp.cast(x, tp.float32) + x = norm(x) + return tp.cast(x, x_dtype) + + shortcut = x # B, H, W, C + x = call_norm(x, self.norm1) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1:3] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + def mod_int(x, y): + return x - (x / y) * y + + pad_h = mod_int((window_size - mod_int(H, window_size)), window_size) + pad_w = mod_int((window_size - mod_int(W, window_size)), 
window_size) + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + x + # MLP + t = call_norm(x, self.norm2) + x = x + self.mlp(t) + return x + + +class Hiera(tp.Module): + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + dtype: str = "float32", + ): + super().__init__() + + self.dtype = dtype + tp_dtype = getattr(tp, dtype) + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed(embed_dim=embed_dim, dtype=tp_dtype) + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = tp.Parameter(tp.zeros((1, embed_dim, *self.window_pos_embed_bkg_spatial_size), dtype=tp_dtype)) + self.pos_embed_window = tp.Parameter( + tp.zeros((1, embed_dim, self.window_spec[0], self.window_spec[0]), dtype=tp_dtype) + ) + self.pos_embed_torch = None + + cur_stage = 1 + self.blocks = [] + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + dtype=tp_dtype, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def generate_static_pos_embed(self, hw: Tuple[int, int]): + import torch + import torch.nn.functional as F + + h, w = hw + window_embed = torch.from_dlpack(self.pos_embed_window) + pos_embed = F.interpolate(torch.from_dlpack(self.pos_embed), size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + self.pos_embed_torch = pos_embed.contiguous() + + def __call__(self, x: tp.Tensor): + x = self.patch_embed(x) + # x: (B, H, W, C) + + # 
Add pos embed + x = x + tp.Tensor(self.pos_embed_torch) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): + # [1, 7, 43, 47] + feats = tp.permute(x, (0, 3, 1, 2)) + outputs.append(feats) + + return outputs diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/image_encoder.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/image_encoder.py new file mode 100755 index 000000000..99a9d0bee --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import tripy as tp + + +class ImageEncoder(tp.Module): + + def __init__( + self, + trunk: tp.Module, + neck: tp.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + self.compiled_executable = None + + def forward(self, x): + # __call__ returns a dict, not tensors + # thus we need to only compile this forward function + # Forward through backbone + return self.neck(self.trunk(x)) + + def __call__(self, sample: tp.Tensor): + import torch + + # Forward through backbone + if self.compiled_executable: + features_pos = self.compiled_executable(sample) + tp.default_stream().synchronize() + else: + features_pos = self.forward(sample) + for i in range(len(features_pos)): + features_pos[i] = torch.from_dlpack(features_pos[i]) + n = len(self.neck.backbone_channel_list) + features = list(features_pos[:n]) + pos = list(features_pos[n:]) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(tp.Module): + + def __init__( + self, + position_encoding: tp.Module, # TODO: replace this with shapes + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + dtype: str = "float32", + ): + super().__init__() + self.dtype = getattr(tp, dtype) + self.convs = [] + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + make_2d_tuple = lambda x: 2 * (x,) + self.convs.append( + tp.Conv( + in_channels=dim, + out_channels=d_model, + kernel_dims=make_2d_tuple(kernel_size), + stride=make_2d_tuple(stride), + dtype=self.dtype, + ) + ) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + if fpn_top_down_levels is None: + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + # position embedding only depends on input shape + # so we generate static embeddings ahead of time + self.position_encoding = [] + position_encoding_shapes = [[256, 256], [128, 128], [64, 64], [32, 32]] + for s in position_encoding_shapes: + self.position_encoding.append(position_encoding.generate_static_embedding([1, 256] + s, dtype=dtype)) + + def __call__(self, xs: List[tp.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = tp.resize( + tp.cast(prev_features, self.dtype), + mode=self.fpn_interp_model, + output_shape=(prev_features.shape[0], 256, 64, 64), + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = tp.cast(self.position_encoding[i], x_out.dtype) + + return *out, *pos diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/utils.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/utils.py new file mode 100755 index 000000000..b58080be2 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/utils.py @@ -0,0 +1,101 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import tripy as tp + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + # padding is not triggered + Hp, Wp = H, W + x = tp.reshape(x, (B, Hp // window_size, window_size, Wp // window_size, window_size, C)) + x = tp.permute(x, (0, 1, 3, 2, 4, 5)) + windows = tp.reshape(x, (-1, window_size, window_size, C)) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + + x = tp.reshape(windows, (B, Hp // window_size, Wp // window_size, window_size, window_size, -1)) + x = tp.permute(x, (0, 1, 3, 2, 4, 5)) # [B, Hp//window_size, window_size, Wp//window_size, window_size, C] + x = tp.reshape(x, (B, Hp, Wp, -1)) # [B, Hp, Wp, C] + return x + + +class PatchEmbed(tp.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + dtype: tp.dtype = tp.float32, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. 
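+            dtype (tp.dtype): data type of the patch projection weights.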
+ """ + super().__init__() + padding = ((padding[0], padding[0]), (padding[1], padding[1])) + self.proj = tp.Conv( + in_chans, + embed_dim, + kernel_dims=kernel_size, + stride=stride, + padding=padding, + dtype=dtype, + ) + + def __call__(self, x): + x = self.proj(x) + x = tp.permute(x, (0, 2, 3, 1)) # [B, C, H, W] -> [B, H, W, C] + return x diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py new file mode 100644 index 000000000..d03a10f21 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +from sam2.modeling.sam.transformer import RoPEAttention +from sam2.modeling.sam2_utils import get_activation_fn + +import tripy as tp + + +class MemoryAttentionLayer(tp.Module): + + def __init__( + self, + activation: str, + cross_attention: tp.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: tp.Module, + dtype: "float32", + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.self_attn = self_attention + self.cross_attn_image = cross_attention + self.dtype = getattr(tp, dtype) + + # Implementation of Feedforward model + self.linear1 = tp.Linear(d_model, dim_feedforward, dtype=self.dtype) + self.linear2 = tp.Linear(dim_feedforward, d_model, dtype=self.dtype) + + self.norm1 = tp.LayerNorm(d_model) + self.norm2 = tp.LayerNorm(d_model) + self.norm3 = tp.LayerNorm(d_model) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = tp.cast(self.norm1(tp.cast(tgt, self.norm1.dtype)), self.dtype) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2, num_k_exclude_rope=tp.Tensor([0])) + tgt = tgt + tgt2 + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + + # Cross-Attention + tgt2 = tp.cast(self.norm2(tp.cast(tgt, self.norm2.dtype)), self.dtype) + + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + num_k_exclude_rope=num_k_exclude_rope, + 
) + tgt = tgt + tgt2 + return tgt + + def __call__( + self, + tgt, + memory, + pos: Optional[tp.Tensor] = None, + query_pos: Optional[tp.Tensor] = None, + num_k_exclude_rope: Optional[tp.Tensor] = None, + ) -> tp.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = tp.cast(self.norm3(tp.cast(tgt, self.norm3.dtype)), self.dtype) + + tgt2 = self.linear2(self.activation(self.linear1(tgt2))) + tgt = tgt + tgt2 + return tgt + + +class MemoryAttention(tp.Module): + + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: tp.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + dtype="float32", + ): + super().__init__() + self.d_model = d_model + self.layers = [layer for i in range(num_layers)] + self.num_layers = num_layers + self.norm = tp.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + self.dtype = getattr(tp, dtype) + + def __call__( + self, + curr: tp.Tensor, # self-attention inputs + memory: tp.Tensor, # cross-attention inputs + curr_pos: Optional[tp.Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[tp.Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: Optional[tp.Tensor] = None, # number of object pointer *tokens* + ): + + num_obj_ptr_tokens = num_obj_ptr_tokens.shape[0] + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = tp.transpose(output, 0, 1) + memory = tp.transpose(memory, 0, 1) + if curr_pos is not None: + curr_pos = tp.transpose(curr_pos, 0, 1) + if memory_pos is not None: + memory_pos = tp.transpose(memory_pos, 0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + + normed_output = tp.cast(self.norm(tp.cast(output, self.norm.dtype)), self.dtype) + + if self.batch_first: + # Convert back to seq first + normed_output = tp.transpose(normed_output, 0, 1) + + return normed_output diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py new file mode 100644 index 000000000..aa317ebca --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
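+# Note on the MaskDownSampler below: with the default kernel_size=stride=4 and
+# total_stride=16, num_layers = log2(16) / log2(4) = 2. Each conv multiplies the
+# channel count by stride**2 and shrinks the spatial dims by the stride, so the
+# channels go 1 -> 16 -> 256 and the mask is downsampled 16x overall before the
+# final 1x1 conv projects to embed_dim. A rough shape sketch, assuming a
+# hypothetical 1024x1024 input mask:
+#   [B, 1, 1024, 1024] -> [B, 16, 256, 256] -> [B, 256, 64, 64] -> [B, embed_dim, 64, 64]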
+ +import math +from typing import Tuple + +import tripy as tp + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class Dummy(tp.Module): + def __init__(self): + pass + + def __call__(self, x): + return x + + +class MaskDownSampler(tp.Module): + """ + Progressively downsample a mask by total_stride. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=tp.gelu, + dtype="float32", + ): + super().__init__() + self.dtype = getattr(tp, dtype) + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = [] + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + tp.Conv( + mask_in_chans, + mask_out_chans, + kernel_dims=(kernel_size, kernel_size), + stride=(stride, stride), + padding=((padding, padding), (padding, padding)), + dtype=self.dtype, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation) + mask_in_chans = mask_out_chans + + self.encoder.append(tp.Conv(mask_out_chans, embed_dim, kernel_dims=(1, 1), dtype=self.dtype)) + + def __call__(self, x): + for l in self.encoder: + x = l(x) + return x + + +class CXBlock(tp.Module): + r"""ConvNeXt Block. + DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + dtype="float32", + ): + super().__init__() + self.dtype = getattr(tp, dtype) + self.dwconv = tp.Conv( + dim, + dim, + kernel_dims=(kernel_size, kernel_size), + padding=((padding, padding), (padding, padding)), + groups=dim if use_dwconv else 1, + dtype=self.dtype, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = tp.Linear(dim, 4 * dim, dtype=self.dtype) # pointwise/1x1 convs, implemented with linear layers + self.act = tp.gelu + self.pwconv2 = tp.Linear(4 * dim, dim, dtype=self.dtype) + self.gamma = tp.ones((dim,), dtype=self.dtype) * layer_scale_init_value + + self.drop_path = Dummy() + + def __call__(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = tp.permute(x, (0, 2, 3, 1)) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = tp.permute(x, (0, 3, 1, 2)) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(tp.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False, dtype="float32"): + super().__init__() + self.dtype = getattr(tp, dtype) + self.proj = Dummy() + self.layers = [layer for i in range(num_layers)] + + if input_projection: + self.proj = tp.Conv(dim, dim, kernel_dims=(1, 1), dtype=self.dtype) + + def __call__(self, x): + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(tp.Module): + def __init__( + self, out_dim, mask_downsampler, fuser, position_encoding, in_dim=256, dtype="float32" # in_dim of pix_feats + ): + super().__init__() + self.dtype = getattr(tp, dtype) + + self.mask_downsampler = mask_downsampler + self.pix_feat_proj = tp.Conv(in_dim, in_dim, kernel_dims=(1, 1), dtype=self.dtype) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = Dummy() + if out_dim != in_dim: + self.out_proj = tp.Conv(in_dim, out_dim, kernel_dims=(1, 1), dtype=self.dtype) + + 
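+    # The call below fuses the downsampled mask with the backbone's pixel features:
+    # sigmoid(masks) -> mask_downsampler -> added to pix_feat_proj(pix_feat) ->
+    # fuser (typically a stack of CXBlocks) -> out_proj. When out_dim != in_dim
+    # (the reference SAM 2 configs use a smaller memory channel count, e.g. 64),
+    # out_proj compresses the channel dimension; otherwise it stays the identity Dummy.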
def __call__( + self, + pix_feat: tp.Tensor, + masks: tp.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[tp.Tensor, tp.Tensor]: + if not skip_mask_sigmoid: + masks = tp.sigmoid(masks) + masks = self.mask_downsampler(masks) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = tp.cast(self.position_encoding(x), x.dtype) + + return x, pos diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/position_encoding.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/position_encoding.py new file mode 100644 index 000000000..1e0f1ed74 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/position_encoding.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import numpy as np +import tripy as tp +from typing import Optional, Tuple +from sam2.modeling.sam2_utils import cartesian_via_polar, mul_as_complex + + +class PositionEmbeddingSine(tp.Module): + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def __call__(self, x: tp.Tensor): + # x: [B, C, H, W] + B, _, H, W = x.shape + y_embed = tp.arange(1, H + 1, dtype=tp.float32) + y_embed = tp.reshape(y_embed, (1, -1, 1)) + y_embed = tp.repeat(y_embed, B, 0) + y_embed = tp.repeat(y_embed, W, 2) # [B, H, W] + + x_embed = tp.arange(1, W + 1, dtype=tp.float32) + x_embed = tp.reshape(x_embed, (1, 1, -1)) + x_embed = tp.repeat(x_embed, B, 0) + x_embed = tp.repeat(x_embed, H, 1) # [B, H, W] + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = tp.arange(self.num_pos_feats, dtype=tp.float32) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = tp.unsqueeze(x_embed, -1) / dim_t + pos_y = tp.unsqueeze(y_embed, -1) / dim_t + pos_x = tp.stack((tp.sin(pos_x[:, :, :, 0::2]), tp.cos(pos_x[:, :, :, 1::2])), dim=4) + pos_y = tp.stack((tp.sin(pos_y[:, :, :, 0::2]), tp.cos(pos_y[:, :, :, 1::2])), dim=4) + pos_x = tp.flatten(pos_x, 3) + pos_y = tp.flatten(pos_y, 3) + pos = tp.concatenate([pos_x, pos_y], dim=3) + pos = tp.permute(pos, (0, 3, 1, 2)) + return pos + + def generate_static_embedding(self, inp_shape, dtype): + import torch + + B, _, H, W = inp_shape + y_embed = torch.arange(1, H + 1, dtype=torch.float32).view(1, -1, 1).repeat(B, 1, W) + x_embed = torch.arange(1, W + 1, dtype=torch.float32).view(1, 1, 
-1).repeat(B, H, 1) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return tp.Tensor(pos.to(getattr(torch, dtype)).contiguous()) + + +class PositionEmbeddingRandom(tp.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.positional_encoding_gaussian_matrix = tp.Parameter( + tp.Tensor( + scale * np.random.randn(2, num_pos_feats).astype(np.float32), + dtype=tp.float32, + ) + ) + + def _pe_encoding(self, coords: tp.Tensor) -> tp.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return tp.concatenate([tp.sin(coords), tp.cos(coords)], dim=-1) + + def __call__(self, size: Tuple[int, int]) -> tp.Tensor: + return self.forward(size) + + def forward(self, size: Tuple[int, int]) -> tp.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + grid = tp.ones((h, w), dtype=tp.float32) + y_embed = tp.cumsum(grid, dim=0) - 0.5 + x_embed = tp.cumsum(grid, dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(tp.stack([x_embed, y_embed], dim=-1)) + return tp.permute(pe, (2, 0, 1)) # C x H x W + + def forward_with_coords(self, coords_input: tp.Tensor, image_size: Tuple[int, int]) -> tp.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + new_x_coords = coords_input[:, :, 0] / image_size[1] + new_y_coords = coords_input[:, :, 1] / image_size[0] + + # Combine the updated x and y coordinates into a new tensor + new_coords = tp.stack([new_x_coords, new_y_coords], dim=-1) + return self._pe_encoding(tp.cast(new_coords, tp.float32)) # B x N x C + + +def init_t_xy(end_x: tp.DimensionSize, end_y: tp.DimensionSize): + t = tp.arange(end_x * end_y, dtype=tp.float32) + if isinstance(end_x, tp.DimensionSize) and isinstance(end_y, tp.DimensionSize): + end_x, end_y = tp.cast(end_x, tp.float32), tp.cast(end_y, tp.float32) + t_x = t % end_x + t_y = t // end_x + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: tp.DimensionSize, end_y: tp.DimensionSize, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (tp.cast(tp.arange(0, dim, 4)[: (dim // 4)], tp.float32) / dim)) + freqs_y = 1.0 / (theta ** (tp.cast(tp.arange(0, dim, 4)[: (dim // 4)], tp.float32) / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = tp.outer(t_x, freqs_x) + freqs_y = tp.outer(t_y, freqs_y) + freqs_cis_x = cartesian_via_polar(tp.ones_like(freqs_x), freqs_x) + freqs_cis_y = cartesian_via_polar(tp.ones_like(freqs_y), freqs_y) + return tp.concatenate([freqs_cis_x, freqs_cis_y], dim=-2) + + +def reshape_for_broadcast(freqs_cis: 
tp.Tensor, x: tp.Tensor): + ndim = x.rank + shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)] + return tp.reshape(freqs_cis, shape) + + +def apply_rotary_enc( + xq: tp.Tensor, + xk: tp.Tensor, + freqs_cis: tp.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = tp.reshape(xq, (*xq.shape[:-1], -1, 2)) + xk_ = tp.reshape(xk, (*xk.shape[:-1], -1, 2)) + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = mul_as_complex(xq_, freqs_cis) + xq_out = tp.reshape(xq_out, (xq_out.shape[0], xq_out.shape[1], xq_out.shape[2], -1)) + + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-3] // xq_.shape[-3] + freqs_cis = tp.flatten(tp.expand(tp.unsqueeze(freqs_cis, 2), (-1, -1, r, -1, -1, -1)), 2, 3) + + xk_out = tp.flatten(mul_as_complex(xk_, freqs_cis), 3) + return tp.cast(xq_out, xq.dtype), tp.cast(xk_out, xk.dtype) diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/__init__.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/__init__.py new file mode 100644 index 000000000..a08b2c204 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py new file mode 100755 index 000000000..c6d4cf2bd --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
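+# Note on the MaskDecoder below: output_upscaling applies two stride-2 transposed
+# convolutions, i.e. a 4x spatial upsampling of the image embedding (a hypothetical
+# 64x64 embedding would become 256x256 mask logits), and the per-token hypernetwork
+# MLPs produce the weights that are matmul'd with the upscaled embedding to yield
+# one mask per mask token.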
+ +from typing import List, Optional, Tuple + +import tripy as tp + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class Dummy(tp.Module): + def __init__(self): + pass + + def __call__(self, x): + return x + + +class MaskDecoder(tp.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: tp.Module, + num_multimask_outputs: int = 3, + activation=tp.gelu, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + dtype=tp.float32, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + + self.transformer_dim = transformer_dim + self.transformer = transformer + self.dtype = dtype + + self.num_multimask_outputs = num_multimask_outputs + self.activation = activation + + self.iou_token = tp.Embedding(1, transformer_dim, dtype=dtype) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = tp.Embedding(self.num_mask_tokens, transformer_dim, dtype=dtype) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = tp.Embedding(1, transformer_dim, dtype=dtype) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = [ + tp.ConvTranspose( + transformer_dim, + transformer_dim // 4, + kernel_dims=(2, 2), + stride=(2, 2), + dtype=dtype, + ), + LayerNorm2d(transformer_dim // 4), + Dummy(), # Accounts for Dropout layer, needed for weight loading + tp.ConvTranspose( + transformer_dim // 4, + transformer_dim // 8, + kernel_dims=(2, 2), + stride=(2, 2), + dtype=dtype, + ), + Dummy(), # Accounts for Dropout layer, needed for weight loading + ] + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = tp.Conv( + transformer_dim, + transformer_dim // 8, + kernel_dims=(1, 1), + stride=(1, 1), + dtype=dtype, + ) + self.conv_s1 = tp.Conv( + transformer_dim, + transformer_dim // 4, + kernel_dims=(1, 1), + stride=(1, 1), + dtype=dtype, + ) + + self.output_hypernetworks_mlps = [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3, dtype=dtype) + for i in range(self.num_mask_tokens) + ] + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + dtype=dtype, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = tp.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3, dtype=dtype) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. 
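+        # Concretely (see _dynamic_multimask_via_stability): if the stability score of
+        # the single-mask token falls below dynamic_multimask_stability_thresh, the mask
+        # with the highest predicted IoU among the multimask tokens is returned instead.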
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def __call__( + self, + image_embeddings: tp.Tensor, + image_pe: tp.Tensor, + sparse_prompt_embeddings: tp.Tensor, + dense_prompt_embeddings: tp.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features_1: Optional[tp.Tensor] = None, + high_res_features_2: Optional[tp.Tensor] = None, + ) -> Tuple[tp.Tensor, tp.Tensor]: + + return self.forward( + image_embeddings, + image_pe, + sparse_prompt_embeddings, + dense_prompt_embeddings, + multimask_output, + repeat_image, + high_res_features_1, + high_res_features_2, + ) + + def forward( + self, + image_embeddings: tp.Tensor, + image_pe: tp.Tensor, + sparse_prompt_embeddings: tp.Tensor, + dense_prompt_embeddings: tp.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features_1: Optional[tp.Tensor] = None, + high_res_features_2: Optional[tp.Tensor] = None, + ) -> Tuple[tp.Tensor, tp.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (tp.Tensor): the embeddings from the image encoder + image_pe (tp.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (tp.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (tp.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + tp.Tensor: batched predicted masks + tp.Tensor: batched predictions of mask quality + tp.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features_1=high_res_features_1, + high_res_features_2=high_res_features_2, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + # masks = masks[:, 0:1, :, :] + # iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: tp.Tensor, + image_pe: tp.Tensor, + sparse_prompt_embeddings: tp.Tensor, + dense_prompt_embeddings: tp.Tensor, + repeat_image: bool, + high_res_features_1: Optional[tp.Tensor] = None, + high_res_features_2: Optional[tp.Tensor] = None, + ) -> Tuple[tp.Tensor, tp.Tensor]: + """Predicts masks. 
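+        Returns mask logits, IoU predictions, the mask output tokens, and object score logits.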
See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = tp.concatenate( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = tp.concatenate([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = tp.expand(tp.unsqueeze(output_tokens, 0), (sparse_prompt_embeddings.shape[0], -1, -1)) + tokens = tp.concatenate((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = tp.repeat(image_embeddings, tokens.shape[0], dim=0) + else: + src = image_embeddings + + src = src + dense_prompt_embeddings + pos_src = tp.repeat(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + og_dtype = src.dtype + src = tp.cast(src, self.dtype) + pos_src = tp.cast(pos_src, self.dtype) + tokens = tp.cast(tokens, self.dtype) + hs, src = self.transformer(src, pos_src, tokens) + hs = tp.cast(hs, og_dtype) + src = tp.cast(src, og_dtype) + + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : s + 1 + self.num_mask_tokens, :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = tp.reshape(tp.transpose(src, 1, 2), (b, c, h, w)) + act1 = self.activation + act2 = self.activation + + if not self.use_high_res_features: + dc1, ln1, _, dc2, _ = self.output_upscaling + post_ln1 = tp.cast(ln1(tp.cast(dc1(src), tp.float32)), src.dtype) + upscaled_embedding = act2(dc2(act1(post_ln1))) + # upscaled_embedding = act2(dc2(act1(ln1(dc1(src))))) + else: + dc1, ln1, _, dc2, _ = self.output_upscaling + feat_s0, feat_s1 = high_res_features_1, high_res_features_2 + post_ln1 = tp.cast(ln1(tp.cast(dc1(src) + feat_s1, tp.float32)), src.dtype) + upscaled_embedding = act1(post_ln1) + # upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[tp.Tensor] = [] + for i in range(self.num_mask_tokens): + mlp_in = mask_tokens_out[:, i, :] + hyper_in_list.append(self.output_hypernetworks_mlps[i](mlp_in)) + hyper_in = tp.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + upscaled_embedding = tp.reshape(upscaled_embedding, (b, c, h * w)) + # out_4 = upscaled_embedding + masks = tp.reshape(hyper_in @ upscaled_embedding, (b, -1, h, w)) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * tp.ones( + (iou_pred.shape[0], 1), dtype=self.dtype + ) # iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. 
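+        The score is area(logits > +delta) / area(logits > -delta), and defaults to 1.0
+        when that denominator is zero.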
+ """ + mask_logits = tp.flatten(mask_logits, -2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = tp.cast(tp.sum(tp.cast(mask_logits > stability_delta, tp.float32), dim=-1), self.dtype) + area_u = tp.cast(tp.sum(tp.cast(mask_logits > -stability_delta, tp.float32), dim=-1), self.dtype) + stability_scores = tp.where(area_u > 0, area_i / area_u, tp.Tensor(1.0, dtype=self.dtype)) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = tp.argmax(multimask_iou_scores, dim=-1) + batch_inds = tp.arange(multimask_iou_scores.shape[0], dtype=self.dtype) + batch_inds = tp.cast(batch_inds, tp.int32) + + def advanced_indexing(tensor, first_index, second_index): + # Step 1: Use the first_index to select rows + step1 = tp.gather(tensor, dim=0, index=first_index) + + # Step 2: Prepare for the second gather operation + batch_size = first_index.shape[0] + row_indices = tp.arange(batch_size, dtype=tp.int32) + + # Step 3: Combine row_indices and second_index + combined_indices = tp.stack([row_indices, second_index], dim=1) + + # Step 4: Flatten the tensor + flattened = tp.flatten(step1) + + # Step 5: Calculate flat indices + flat_indices = combined_indices[:, 0] * batch_size + combined_indices[:, 1] + + # Step 6: Gather using flat indices + result = tp.gather(flattened, dim=0, index=flat_indices) + + return result + + best_multimask_logits = advanced_indexing(multimask_logits, batch_inds, best_scores_inds) + best_multimask_iou_scores = advanced_indexing(multimask_iou_scores, batch_inds, best_scores_inds) + + best_multimask_logits = tp.unsqueeze(best_multimask_logits, 1) + best_multimask_iou_scores = tp.unsqueeze(best_multimask_iou_scores, 1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + is_stable = tp.unsqueeze(tp.unsqueeze(is_stable, -1), -1) + mask_logits_out = tp.where( + is_stable, + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = tp.where( + is_stable, + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/prompt_encoder.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/prompt_encoder.py new file mode 100755 index 000000000..3a3e0e23b --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Type + +import tripy as tp +from sam2.modeling.position_encoding import PositionEmbeddingRandom +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(tp.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[tp.Module] = tp.gelu, + dtype=tp.float32, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (tp.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + self.point_embeddings = [tp.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.not_a_point_embed = tp.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = [ + tp.Conv(1, mask_in_chans // 4, kernel_dims=(2, 2), stride=(2, 2)), + LayerNorm2d(mask_in_chans // 4), + activation, + tp.Conv(mask_in_chans // 4, mask_in_chans, kernel_dims=(2, 2), stride=(2, 2)), + LayerNorm2d(mask_in_chans), + activation, + tp.Conv(mask_in_chans, embed_dim, kernel_dims=(1, 1)), + ] + self.no_mask_embed = tp.Embedding(1, embed_dim) + self.dtype = dtype + + def get_dense_pe(self) -> tp.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
+ + Returns: + tp.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + dense_pe = tp.unsqueeze(self.pe_layer(self.image_embedding_size), 0) + return tp.cast(dense_pe, self.dtype) + + def _embed_points( + self, + points: tp.Tensor, + labels: tp.Tensor, + pad: bool, + ) -> tp.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = tp.zeros((points.shape[0], 1, 2), dtype=points.dtype) + padding_label = 0 - tp.ones((labels.shape[0], 1), dtype=labels.dtype) + padding_label = tp.cast(padding_label, labels.dtype) + points = tp.concatenate([points, padding_point], dim=1) + labels = tp.concatenate([labels, padding_label], dim=1) + + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + labels = tp.unsqueeze(labels, 2) + point_embedding = tp.where(labels == -1, tp.Tensor([0.0]), point_embedding) + point_embedding = tp.where( + labels == -1, + point_embedding + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = tp.where( + labels == 0, + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = tp.where( + labels == 1, + point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = tp.where( + labels == 2, + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = tp.where( + labels == 3, + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) + return point_embedding + + def _embed_boxes(self, boxes: tp.Tensor) -> tp.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = tp.reshape(boxes, (-1, 2, 2)) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + + corner_embedding_0 = corner_embedding[:, 0, :] + self.point_embeddings[2].weight + corner_embedding_1 = corner_embedding[:, 1, :] + self.point_embeddings[3].weight + + # Combine the updated x and y coordinates into a new tensor + new_corner_embedding = tp.stack([corner_embedding_0, corner_embedding_1], dim=1) + + return new_corner_embedding + + def _embed_masks(self, masks: tp.Tensor) -> tp.Tensor: + """Embeds mask inputs.""" + mask_embedding = masks + for l in self.mask_downscaling: + mask_embedding = l(mask_embedding) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[tp.Tensor, tp.Tensor]], + boxes: Optional[tp.Tensor], + masks: Optional[tp.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def __call__( + self, + points_x: Optional[tp.Tensor], + points_y: Optional[tp.Tensor], + boxes: Optional[tp.Tensor], + masks: Optional[tp.Tensor], + ) -> Tuple[tp.Tensor, tp.Tensor]: + return self.forward(points_x, points_y, boxes, masks) + + def forward( + self, + points_x: Optional[tp.Tensor], + points_y: Optional[tp.Tensor], + boxes: Optional[tp.Tensor], + masks: Optional[tp.Tensor], + ) -> Tuple[tp.Tensor, tp.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(tp.Tensor, tp.Tensor) or none): point coordinates + and labels to embed. 
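+            Labels follow the SAM point convention handled in _embed_points:
+            1 = positive click, 0 = negative click, 2 and 3 = box corners, -1 = padding.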
+ boxes (tp.Tensor or none): boxes to embed + masks (tp.Tensor or none): masks to embed + + Returns: + tp.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + tp.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + points = (points_x, points_y) + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = tp.zeros((bs, 0, self.embed_dim)) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = tp.concatenate([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = tp.concatenate([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = tp.reshape(self.no_mask_embed.weight, (1, -1, 1, 1)) + dense_embeddings = tp.expand( + dense_embeddings, + (bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]), + ) + + sparse_embeddings = tp.cast(sparse_embeddings, self.dtype) + dense_embeddings = tp.cast(dense_embeddings, self.dtype) + return sparse_embeddings, dense_embeddings diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py new file mode 100755 index 000000000..7c0dfb15b --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Tuple, Type + +import tripy as tp +from tripy import Tensor +from sam2.modeling.sam2_utils import MLP, scaled_dot_product_attention +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis + + +class TwoWayTransformer(tp.Module): + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[tp.Module] = tp.relu, + attention_downsample_rate: int = 2, + dtype=tp.float32, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. 
Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = [] + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + dtype=dtype, + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + dtype=dtype, + ) + self.norm_final_attn = tp.LayerNorm(embedding_dim) + + def __call__( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + + return self.forward(image_embedding, image_pe, point_embedding) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (tp.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (tp.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (tp.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + tp.Tensor: the processed point_embedding + tp.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = tp.permute(tp.flatten(image_embedding, 2), (0, 2, 1)) + image_pe = tp.permute(tp.flatten(image_pe, 2), (0, 2, 1)) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = tp.cast( + self.norm_final_attn(tp.cast(queries, self.norm_final_attn.dtype)), + queries.dtype, + ) + # queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(tp.Module): + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[tp.Module] = tp.relu, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + dtype=tp.float32, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads, dtype=dtype) + self.norm1 = tp.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + dtype=dtype, + ) + self.norm2 = tp.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, + mlp_dim, + embedding_dim, + num_layers=2, + activation=activation, + dtype=dtype, + ) + self.norm3 = tp.LayerNorm(embedding_dim) + + self.norm4 = tp.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + dtype=dtype, + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def __call__(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: + return self.forward(queries, keys, query_pe, key_pe) + + def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + + queries = tp.cast(self.norm1(tp.cast(queries, self.norm1.dtype)), queries.dtype) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + + queries = tp.cast(self.norm2(tp.cast(queries, self.norm2.dtype)), queries.dtype) + # queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = tp.cast(self.norm3(tp.cast(queries, self.norm3.dtype)), queries.dtype) + # queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = tp.cast(self.norm4(tp.cast(keys, self.norm4.dtype)), keys.dtype) + # keys = self.norm4(keys) + + return queries, keys + + +class Attention(tp.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + dtype=tp.float32, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
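+        # Example with hypothetical values: embedding_dim=256 and downsample_rate=2
+        # give internal_dim=128, so with num_heads=8 each head operates on
+        # 128 // 8 = 16 channels.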
+ self.q_proj = tp.Linear(embedding_dim, self.internal_dim, dtype=dtype) + self.k_proj = tp.Linear(self.kv_in_dim, self.internal_dim, dtype=dtype) + self.v_proj = tp.Linear(self.kv_in_dim, self.internal_dim, dtype=dtype) + self.out_proj = tp.Linear(self.internal_dim, embedding_dim, dtype=dtype) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape[0], x.shape[1], x.shape[2] + x = tp.reshape(x, [b, n, num_heads, c // num_heads]) + return tp.transpose(x, 1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_head, n_token, c_per_head = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + x = tp.transpose(x, 1, 2) + return tp.reshape(x, [b, n_token, n_head * c_per_head]) # B x N_tokens x C + + def __call__(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + return self.forward(q, k, v) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + out = scaled_dot_product_attention(q, k, v, embedding_dim=k.shape[-1]) + out = self._recombine_heads(out) + out = self.out_proj(out) + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + dtype="float32", + **kwargs, + ): + self.dtype = getattr(tp, dtype) + super().__init__(*args, dtype=self.dtype, **kwargs) + self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def __call__(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: tp.Tensor) -> Tensor: + return self.forward(q, k, v, num_k_exclude_rope) + + def forward(self, q: tp.Tensor, k: tp.Tensor, v: tp.Tensor, num_k_exclude_rope: tp.Tensor) -> tp.Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + # w = h = tp.DimensionSize(tp.cast(tp.sqrt(tp.cast(q.shape[-2], tp.float32)), tp.int32)) # DDS? + w = h = tp.DimensionSize(64) # Current demo always uses 64. 
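+        # Rotary encoding is applied only to the first (seq_len - num_k_exclude_rope)
+        # keys; the object-pointer tokens appended at the end of the memory are not
+        # spatial, so they are excluded from RoPE and concatenated back unchanged below.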
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h) + self.freqs_cis = tp.cast(self.freqs_cis, self.dtype) + + num_k_rope = k.shape[-2] - num_k_exclude_rope + q, new_k = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + k = tp.concatenate([new_k, k[:, :, num_k_rope:, :]], dim=-2) + + # Attention + out = scaled_dot_product_attention(q, k, v, embedding_dim=k.shape[-1]) + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py new file mode 100644 index 000000000..eb65f5395 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py @@ -0,0 +1,844 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F +import tripy as tp + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer + +from sam2.modeling.sam2_utils import get_1d_sine_pe, TorchMLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. 
+ max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding 
with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + model_precision="float32", + ): + super().__init__() + self.model_dtype = getattr(tp, model_precision) + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): + # if there is compression of memories along channel dim + self.mem_dim = int(self.memory_encoder.out_proj.weight.shape[0]) + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = 
multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + dtype=self.model_dtype, + ) + + transformer = TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + dtype=self.model_dtype, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=transformer, + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + dtype=self.model_dtype, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = TorchMLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. 
+
+        Inputs:
+        - backbone_features: image features of [B, C, H, W] shape
+        - point_inputs: a dictionary with "point_coords" and "point_labels", where
+          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+             absolute pixel-unit coordinates in (x, y) format of the P input points
+          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+             positive clicks, 0 means negative clicks, and -1 means padding
+        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+          same spatial size as the image.
+        - high_res_features: either 1) None or 2) a list of length 2 containing
+          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+          which will be used as high-resolution feature maps for the SAM decoder.
+        - multimask_output: if it's True, we output 3 candidate masks and their 3
+          corresponding IoU estimates, and if it's False, we output only 1 mask and
+          its corresponding IoU estimate.
+
+        Outputs:
+        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+          output mask logits (before sigmoid) for the low-resolution masks, with 4x
+          the resolution (1/4 stride) of the input backbone_features.
+        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+          if `multimask_output=True` and M = 1 if `multimask_output=False`),
+          upsampled from the low-resolution masks, with the same spatial size as the
+          image (stride of 1 pixel).
+        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
+          if `multimask_output=False`), the estimated IoU of each output mask.
+        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `low_res_multimasks`.
+        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `high_res_multimasks`.
+        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+          based on the output token from the SAM mask decoder.
+ """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + sam_point_coords = tp.Tensor(sam_point_coords.contiguous()) + sam_point_labels = tp.Tensor(sam_point_labels.contiguous()) + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points_x=sam_point_coords, + points_y=sam_point_labels, + ) + self.dense_pe = self.sam_prompt_encoder.get_dense_pe() + hres_1 = high_res_features[0] + hres_2 = high_res_features[1] + if self.model.model_dtype == tp.float16: + image_embedding = image_embedding.half() + hres_1 = hres_1.half() + hres_2 = hres_2.half() + + tp_backbone_features = tp.Tensor(backbone_features.contiguous()) + hres_1 = tp.Tensor(hres_1.contiguous()) + hres_2 = tp.Tensor(hres_2.contiguous()) + + if multimask_output: + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder_true( + image_embeddings=tp_backbone_features, + image_pe=self.dense_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + high_res_features_1=hres_1, + high_res_features_2=hres_2, + ) + + else: + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder_false( + image_embeddings=tp_backbone_features, + image_pe=self.dense_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + high_res_features_1=hres_1, + high_res_features_2=hres_2, + ) + + low_res_multimasks = torch.from_dlpack(low_res_multimasks).float() + ious = torch.from_dlpack(ious).float() + sam_output_tokens = torch.from_dlpack(sam_output_tokens).float() + object_score_logits = torch.from_dlpack(object_score_logits).float() + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, 
high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. 
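# --- Illustrative sketch (editorial aside, not part of the diff) ----------------
# A minimal, standalone demonstration of the +/-10 logit trick used above and of
# deriving the object score directly from the input mask, as the comment describes.
# Shapes and values here are made up for illustration.
import torch

out_scale, out_bias = 20.0, -10.0          # sigmoid(-10.0) ~= 4.54e-05, sigmoid(+10.0) ~= 1
mask_inputs = torch.zeros(1, 1, 64, 64)
mask_inputs[..., 16:48, 16:48] = 1.0       # a binary "ground-truth" mask

mask_logits = mask_inputs.float() * out_scale + out_bias   # 1 -> +10, 0 -> -10
probs = torch.sigmoid(mask_logits)                         # ~1 inside the box, ~0 outside

# Object presence is decided from the mask itself rather than from a predicted score:
is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1, keepdim=True)
object_score_logits = out_scale * is_obj_appearing.float() + out_bias   # +10 if any pixel is on
# --------------------------------------------------------------------------------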
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: tp.Tensor): + """Get the image feature on the input batch.""" + + backbone_out = self.image_encoder(img_batch) + + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + conv_s0_in = backbone_out["backbone_fpn"][0].contiguous() + conv_s1_in = backbone_out["backbone_fpn"][1].contiguous() + + if self.model_dtype == tp.float32: + conv_s0_in = tp.Tensor(conv_s0_in) + conv_s1_in = tp.Tensor(conv_s1_in) + else: + conv_s0_in = tp.Tensor(conv_s0_in.half()) + conv_s1_in = tp.Tensor(conv_s1_in.half()) + + backbone_out["backbone_fpn"][0] = torch.from_dlpack(self.sam_mask_decoder.conv_s0(conv_s0_in)) + backbone_out["backbone_fpn"][1] = torch.from_dlpack(self.sam_mask_decoder.conv_s1(conv_s1_in)) + + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. 
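# --- Illustrative sketch (editorial aside, not part of the diff) ----------------
# When num_maskmem == 0, the current frame's tokens are simply reshaped from the
# flattened (HW, B, C) layout back to (B, C, H, W) and returned without any memory
# fusion. A toy example of that reshape, with made-up sizes:
import torch

B, C, H, W = 2, 256, 64, 64
current_feat = torch.randn(H * W, B, C)                   # flattened tokens: (HW, B, C)
pix_feat = current_feat.permute(1, 2, 0).view(B, C, H, W)
assert pix_feat.shape == (B, C, H, W)
# --------------------------------------------------------------------------------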
+ if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). 
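# --- Illustrative sketch (editorial aside, not part of the diff) ----------------
# The index arithmetic in the loop above selects which previous frames feed the
# memory bank when tracking forward with a temporal stride r. A hypothetical helper
# (for illustration only) that reproduces the same indices:
def memory_frame_indices(frame_idx: int, num_maskmem: int, r: int) -> list:
    indices = []
    for t_pos in range(1, num_maskmem):
        t_rel = num_maskmem - t_pos
        if t_rel == 1:
            prev = frame_idx - 1                  # always the immediately preceding frame
        else:
            prev = ((frame_idx - 2) // r) * r     # nearest multiple of r at or before frame_idx - 2
            prev -= (t_rel - 2) * r               # then step back in strides of r
        indices.append(prev)
    return indices

# With num_maskmem=7 and stride r=2 at frame 15, the bank holds frames 4, 6, 8, 10, 12, 14.
assert memory_frame_indices(frame_idx=15, num_maskmem=7, r=2) == [4, 6, 8, 10, 12, 14]
# --------------------------------------------------------------------------------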
+ feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). 
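# --- Illustrative sketch (editorial aside, not part of the diff) ----------------
# The temporal positional embedding computed below is a standard 1D sine embedding
# over each pointer's distance from the current frame, normalized by the maximum
# pointer count; it mirrors the get_1d_sine_pe helper defined in sam2_utils.py.
# The distances and the assumption max_obj_ptrs_in_encoder=16 are for illustration.
import torch

def sine_pe(pos, dim, temperature=10000.0):
    pe_dim = dim // 2
    dim_t = torch.arange(pe_dim, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
    pe = pos.unsqueeze(-1) / dim_t
    return torch.cat([pe.sin(), pe.cos()], dim=-1)

t_diffs = torch.tensor([0.0, 1.0, 2.0, 3.0])   # distance of each object pointer from the current frame
obj_pos = sine_pe(t_diffs / 15.0, dim=64)      # normalized by t_diff_max = max_obj_ptrs_in_encoder - 1
assert obj_pos.shape == (4, 64)
# --------------------------------------------------------------------------------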
+ if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + if isinstance(self.memory_attention, tp.Module) or isinstance(self.memory_attention, tp.Executable): + fake_obj_ptrs = torch.ones((num_obj_ptr_tokens,), dtype=torch.int32) + pix_feat_with_mem = self.memory_attention( + curr=tp.Tensor(current_vision_feats[0].half().contiguous()), + memory=tp.Tensor(memory.half().contiguous()), + curr_pos=tp.Tensor(current_vision_pos_embeds[0].half().contiguous()), + memory_pos=tp.Tensor(memory_pos_embed.half().contiguous()), + num_obj_ptr_tokens=tp.Tensor(fake_obj_ptrs.contiguous()), + ) + else: + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats[0], + curr_pos=current_vision_pos_embeds[0], + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + pix_feat_with_mem = torch.from_dlpack(pix_feat_with_mem) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). 
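# --- Illustrative sketch (editorial aside, not part of the diff) ----------------
# The non-overlap rule applied below keeps, at every pixel, only the object with the
# highest logit and clamps all other objects' logits to at most -10, mirroring the
# _apply_non_overlapping_constraints helper defined later in this class. Toy sizes:
import torch

pred_masks = torch.randn(3, 1, 8, 8)                           # 3 objects, one logit map each
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)   # winning object index per pixel
batch_obj_inds = torch.arange(pred_masks.size(0))[:, None, None, None]
keep = max_obj_inds == batch_obj_inds
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

# After this, at most one object per pixel can have sigmoid(logit) > 0.5.
assert (torch.sigmoid(pred_masks) > 0.5).sum(dim=0).max().item() <= 1
# --------------------------------------------------------------------------------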
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + + maskmem_features, maskmem_pos_enc = self.memory_encoder( + tp.Tensor(pix_feat.contiguous()), tp.Tensor(mask_for_mem.contiguous()) + ) # sigmoid already applied + maskmem_features = torch.from_dlpack(maskmem_features) + maskmem_pos_enc = [torch.from_dlpack(maskmem_pos_enc)] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py new file mode 100755 index 000000000..0a64bd69a --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +import tripy as tp + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs} + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def scaled_dot_product_attention( + query: tp.Tensor, + key: tp.Tensor, + value: tp.Tensor, + embedding_dim: Optional[int] = None, + attn_mask: Optional[tp.Tensor] = None, + is_causal: bool = False, +) -> tp.Tensor: + """ + Computes scaled dot-product attention. + `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor. 
+ + - Described: https://paperswithcode.com/method/scaled + - Paper: https://arxiv.org/abs/1706.03762v7 + """ + + if is_causal: + target_shape = query.shape[-2:-1] + key.shape[-2:-1] + # TODO: #228: WAR to prevent computing output rank in infer_rank for reshape + target_shape.trace_tensor.shape = (2,) + attn_mask = tp.cast(tp.tril(tp.ones(target_shape)), tp.bool) + if attn_mask is not None and attn_mask.dtype == tp.bool: + attn_mask = tp.where( + (attn_mask == 0), + tp.ones_like(attn_mask) * -float("inf"), + tp.zeros_like(attn_mask), + ) + if embedding_dim is None: + embedding_dim = query.shape[-1] + qk = query @ tp.transpose(key, -2, -1) / tp.sqrt(tp.cast(embedding_dim, query.dtype)) + return ( + tp.cast( + tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1), + query.dtype, + ) + @ value + ) + + +class TorchMLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +class MLP(tp.Module): + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: Callable[[tp.Tensor], tp.Tensor] = tp.relu, + sigmoid_output: bool = False, + dtype=tp.float32, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = [] + for n, k in zip([input_dim] + h, h + [output_dim]): + self.layers.append(tp.Linear(n, k, dtype=dtype)) + + self.sigmoid_output = sigmoid_output + self.act = activation + + def __call__(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = tp.sigmoid(x) + return x + + +class LayerNorm2d(tp.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + from tripy.frontend.module.parameter import DefaultParameter + + self.weight = DefaultParameter((num_channels,), tp.float32) + self.bias = DefaultParameter((num_channels,), tp.float32) + self.eps = eps + + def __call__(self, x: tp.Tensor) -> tp.Tensor: + original_dtype = x.dtype + x = tp.cast(x, tp.float32) + u = tp.mean(x, dim=1, keepdim=True) + s = tp.mean((x - u) ** 2, dim=1, keepdim=True) + x = (x - u) / tp.sqrt(s + self.eps) + w = tp.unsqueeze(tp.unsqueeze(self.weight, 1), 2) + b = tp.unsqueeze(tp.unsqueeze(self.bias, 1), 2) + x = w * x + b + x = tp.cast(x, original_dtype) + return x + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return tp.relu + if activation == "gelu": + return tp.gelu + if activation == "glu": + return tp.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return [copy.deepcopy(module) for _ in range(N)] + + +def cartesian_via_polar(abs, angles): + r""" + Constructs the real-valued cartesian coordinates from magnitude and angle representing polar coordinates. 
For input + ``abs`` and ``angles`` of shape :math:`(m_1, m_2, \ldots, m_i),` this function returns a new real tensor of shape + """ + real = abs * tp.cos(angles) + imag = abs * tp.sin(angles) + return tp.stack([real, imag], dim=-1) + + +def mul_as_complex(tensor1, tensor2): + r""" + Multiplies two tensors (elementwise) as if they were complex-valued. + The last dimension for both tensors must be 2, representing the real and imaginary components. + """ + flattened1 = tensor1 + flattened2 = tensor2 + + real = flattened1[:, :, :, :, 0] * flattened2[:, :, :, :, 0] - flattened1[:, :, :, :, 1] * flattened2[:, :, :, :, 1] + imag = flattened1[:, :, :, :, 0] * flattened2[:, :, :, :, 1] + flattened1[:, :, :, :, 1] * flattened2[:, :, :, :, 0] + return tp.stack([real, imag], dim=-1) diff --git a/tripy/examples/segment-anything-model-v2/sam2/sam2_image_predictor.py b/tripy/examples/segment-anything-model-v2/sam2/sam2_image_predictor.py new file mode 100644 index 000000000..5a8bdc80d --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/sam2_image_predictor.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import tripy as tp + +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + **kwargs, + ) -> None: + """ + Compute image embedding for a given image and then perform mask prediction using the user provided prompt. + """ + super().__init__() + self.model = sam_model + self.device = torch.device("cuda") + + # Transforms using torch + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. 
+ """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + assert len(set(self._orig_hw)) == 1, "Images in the batch must have the same size." + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + img_batch_tp = tp.Tensor(img_batch.to(getattr(torch, self.model.image_encoder.trunk.dtype)).contiguous()) + backbone_out = self.model.forward_image(img_batch_tp) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed.to(vision_feats[-1].dtype) + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """ + Predict masks for the given input prompts, using the currently set images. + + Arguments: + point_coords_batch: A list of Nx2 arrays of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels_batch: A list of length N arrays of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + multimask_output: If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits: If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords: If true, the point coordinates will be normalized to + the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + masks: A list of output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + ious: A list of arrays of length C containing the model's + predictions for the quality of each mask. + low_res_masks: A list of arrays of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError("An image must be set with .set_image_batch(...) 
before mask prediction.") + + def concat_batch(x): + if x is None: + return x + return np.concatenate(x, axis=0) + + point_coords = concat_batch(point_coords_batch) + point_labels = concat_batch(point_labels_batch) + + _, unnorm_coords, labels, _ = self._prep_prompts( + point_coords, + point_labels, + None, # box + None, # mask_input + normalize_coords, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + multimask_output, + return_logits=return_logits, + ) + + def to_np_list(x): + x = x.float().detach().cpu().numpy() + return [xi for xi in x] + + all_masks = to_np_list(masks) + all_ious = to_np_list(iou_predictions) + all_low_res_masks = to_np_list(low_res_masks) + + return all_masks, all_ious, all_low_res_masks + + def _prep_prompts(self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1): + """ + point_coords: [B, 2] -> [B, 1, 2] + point_labels: [B] -> [B, 1] + """ + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert point_labels is not None, "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords.unsqueeze(1), labels.unsqueeze(1) + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor(mask_logits, dtype=torch.float, device=self.device) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. 
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points_x=tp.Tensor(point_coords.contiguous()), + points_y=tp.Tensor(point_labels.contiguous()), + ) + + # Predict masks + self.dense_pe = self.model.sam_prompt_encoder.get_dense_pe() + image_embedding = self._features["image_embed"].contiguous() + high_res_features_1 = self._features["high_res_feats"][0].contiguous() + high_res_features_2 = self._features["high_res_feats"][1].contiguous() + + if self.model.model_dtype == tp.float16: + image_embedding = image_embedding.half() + high_res_features_1 = high_res_features_1.half() + high_res_features_2 = high_res_features_2.half() + + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder_true( + image_embeddings=tp.Tensor(image_embedding), + image_pe=self.dense_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + high_res_features_1=tp.Tensor(high_res_features_1), + high_res_features_2=tp.Tensor(high_res_features_2), + ) + low_res_masks = torch.from_dlpack(low_res_masks) + iou_predictions = torch.from_dlpack(iou_predictions) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks(low_res_masks, self._orig_hw[img_idx]) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/tripy/examples/segment-anything-model-v2/sam2/utils/__init__.py b/tripy/examples/segment-anything-model-v2/sam2/utils/__init__.py new file mode 100644 index 000000000..4e3ee0298 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/utils/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tripy/examples/segment-anything-model-v2/sam2/utils/transforms.py b/tripy/examples/segment-anything-model-v2/sam2/utils/transforms.py new file mode 100644 index 000000000..b0fcbe638 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/utils/transforms.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__(self, resolution): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords(self, coords: torch.Tensor, normalize=False, orig_hw=None) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes(self, boxes: torch.Tensor, normalize=False, orig_hw=None) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. 
+ """ + masks = masks.float() + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/tripy/tests/test_examples.py b/tripy/tests/test_examples.py index 1faf6ed01..272c23823 100644 --- a/tripy/tests/test_examples.py +++ b/tripy/tests/test_examples.py @@ -80,12 +80,40 @@ def __str__(self): return os.path.relpath(self.path, EXAMPLES_ROOT) -EXAMPLES = [Example(["nanogpt"])] +EXAMPLES = [ + Example(["nanogpt"]), + Example(["segment-anything-model-v2"], artifact_names=["truck.jpg", "saved_engines/", "output/", "checkpoints/"]), +] @pytest.mark.l1 @pytest.mark.parametrize("example", EXAMPLES, ids=lambda case: str(case)) def test_examples(example, sandboxed_install_run): + + def test_with_tolerance(expected, actual, tolerance): + return (abs(float(actual) - float(expected)) / float(expected)) * 100 <= float(tolerance) + + def process_tolerances(expected_output): + specs = [] + placeholder_regex = r"{(\d+\.?\d*)~(\d+)%}" + pattern = expected_output + + # Replace tolerance patterns with more flexible capture group + matches = list(re.finditer(placeholder_regex, pattern)) + for match in matches: + specs.append((match.group(1), match.group(2))) + pattern = pattern.replace(match.group(0), r"(\d+\.?\d*)", 1) + + # Escape parentheses but not our capture group + pattern = pattern.replace("(", r"\(") + pattern = pattern.replace(")", r"\)") + pattern = pattern.replace(r"\(\d+\.?\d*\)", r"(\d+\.?\d*)") + + # Make whitespace flexible + pattern = pattern.replace(" ", r"\s+") + + return pattern.strip(), specs + with open(example.readme, "r", encoding="utf-8") as f: contents = f.read() # Check that the README has all the expected sections. @@ -101,9 +129,20 @@ def test_examples(example, sandboxed_install_run): code = str(block) if block.has_marker("test: expected_stdout"): - print("Checking command output against expected output: ", end="") out = statuses[-1].stdout.strip() - matched = re.match(dedent(code).strip(), out) + expected = dedent(code).strip() + pattern, specs = process_tolerances(expected) + + match = re.search(pattern, out) + if match and specs: + # Check if captured numbers are within tolerance + matched = all( + test_with_tolerance(expected, actual, tolerance) + for (expected, tolerance), actual in zip(specs, match.groups()) + ) + else: + matched = bool(match) + print("matched!" if matched else "did not match!") print(f"==== STDOUT ====\n{out}") assert matched
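# --- Illustrative sketch (editorial aside, not part of the diff) ----------------
# How the tolerance placeholders handled above are meant to work: an expected-output
# block in a README can write `{value~tolerance%}`, which the test turns into a
# numeric capture group and then checks within the given percentage tolerance.
# The expected_output and actual_stdout strings below are made up for illustration.
import re

placeholder_regex = r"{(\d+\.?\d*)~(\d+)%}"
expected_output = "mean IoU: {0.87~5%}"
actual_stdout = "mean IoU: 0.89"

pattern, specs = expected_output, []
for m in re.finditer(placeholder_regex, expected_output):
    specs.append((m.group(1), m.group(2)))                    # (expected value, tolerance %)
    pattern = pattern.replace(m.group(0), r"(\d+\.?\d*)", 1)  # swap placeholder for a capture group

match = re.search(pattern, actual_stdout)
matched = bool(match) and all(
    abs(float(actual) - float(expected)) / float(expected) * 100 <= float(tolerance)
    for (expected, tolerance), actual in zip(specs, match.groups())
)
assert matched  # 0.89 is within 5% of 0.87
# --------------------------------------------------------------------------------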