diff --git a/tripy/examples/segment-anything-model-v2/README.md b/tripy/examples/segment-anything-model-v2/README.md
new file mode 100644
index 000000000..252b20c84
--- /dev/null
+++ b/tripy/examples/segment-anything-model-v2/README.md
@@ -0,0 +1,27 @@
+# SAM2: Segment Anything in Images and Videos
+
+## Introduction
+
+This is an implementation of the SAM2 model ([original repository](https://github.com/facebookresearch/sam2/tree/main) by Meta).
+
+## Running The Example
+
+### Image pipeline
+
+1. Install prerequisites:
+
+    ```bash
+    apt-get install ffmpeg libsm6 libxext6 -y
+    python3 -m pip install -r requirements.txt
+    mkdir checkpoints && wget -P checkpoints https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
+    ```
+
+2. Run the example:
+
+    ```bash
+    python3 image_demo.py
+    ```
+
+### Video segmentation pipeline
+
+TBD
diff --git a/tripy/examples/segment-anything-model-v2/image_demo.py b/tripy/examples/segment-anything-model-v2/image_demo.py
index 2e776aae8..a0dcce045 100644
--- a/tripy/examples/segment-anything-model-v2/image_demo.py
+++ b/tripy/examples/segment-anything-model-v2/image_demo.py
@@ -12,39 +12,58 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
-import torch
-from sam2.build_sam import build_sam2
-from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+import cv2
 import os
+import time
 import numpy as np
 import torch
+import tripy as tp
 import matplotlib.pyplot as plt
+
+plt.switch_backend("agg")  # Switch to non-interactive backend
 from PIL import Image
-import tripy as tp
+from typing import Tuple, Optional, Dict
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 
-def show_mask(mask, ax, random_color=False, borders=True):
+
+def process_and_show_mask(
+    mask: np.ndarray, ax: plt.Axes, random_color: bool = False, borders: bool = True
+) -> np.ndarray:
+    """
+    Process and display a segmentation mask, returning the processed mask for testing.
+    """
+    # Generate mask color
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
     else:
         color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+
+    # Process mask
     h, w = mask.shape[-2:]
     mask = mask.astype(np.uint8)
     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-    if borders:
-        import cv2
 
+    if borders:
         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-        # Try to smooth contours
         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
         mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+
     ax.imshow(mask_image)
+    return mask_image
 
-def show_points(coords, labels, ax, marker_size=375):
+def show_points(
+    coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Display point prompts and return point coordinates for testing.
+    """
     pos_points = coords[labels == 1]
     neg_points = coords[labels == 0]
+
     ax.scatter(
         pos_points[:, 0],
         pos_points[:, 1],
@@ -64,92 +83,137 @@ def show_points(coords, labels, ax, marker_size=375):
         linewidth=1.25,
     )
 
+    return pos_points, neg_points
+
-def show_box(box, ax):
+def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]:
+    """
+    Display a bounding box and return its coordinates for testing.
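+
+    A minimal usage sketch (the coordinates are hypothetical; the box is given
+    as [x0, y0, x1, y1], matching the width/height arithmetic below):
+
+        fig, ax = plt.subplots()
+        x0, y0, w, h = show_box(np.array([100, 150, 400, 350]), ax)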
+    """
     x0, y0 = box[0], box[1]
     w, h = box[2] - box[0], box[3] - box[1]
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
+    return x0, y0, w, h
+
+
+def process_predictions(
+    image: np.ndarray,
+    masks: np.ndarray,
+    scores: np.ndarray,
+    logits: np.ndarray,
+    point_coords: Optional[np.ndarray] = None,
+    box_coords: Optional[np.ndarray] = None,
+    input_labels: Optional[np.ndarray] = None,
+    save_path: Optional[str] = None,
+) -> Dict[str, np.ndarray]:
+    """
+    Process and visualize predictions, returning a dictionary containing processed masks, scores, and logits.
+    """
+    processed_masks = []
+
+    # Create output directory if it doesn't exist
+    if save_path:
+        os.makedirs(save_path, exist_ok=True)
 
-
-def show_masks(
-    image,
-    masks,
-    scores,
-    point_coords=None,
-    box_coords=None,
-    input_labels=None,
-    borders=True,
-):
     for i, (mask, score) in enumerate(zip(masks, scores)):
-        plt.figure(figsize=(10, 10))
-        plt.imshow(image)
-        show_mask(mask, plt.gca(), borders=borders)
+
+        fig, ax = plt.subplots(figsize=(10, 10))
+        ax.imshow(image)
+
+        processed_mask = process_and_show_mask(mask, ax)
+        processed_masks.append(processed_mask)
+
         if point_coords is not None:
-            assert input_labels is not None
-            show_points(point_coords, input_labels, plt.gca())
+            assert input_labels is not None, "Input labels required for point prompts"
+            show_points(point_coords, input_labels, ax)
+
         if box_coords is not None:
-            # boxes
-            show_box(box_coords, plt.gca())
+            show_box(box_coords, ax)
+
         if len(scores) > 1:
-            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
-        plt.axis("off")
-        plt.show()
-        plt.savefig(f"mask{i}.png")
+            ax.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+        ax.axis("off")
 
-torch.set_printoptions(threshold=10)
+        if save_path:
+            plt.savefig(os.path.join(save_path, f"mask_{i}_score_{score:.3f}.png"), bbox_inches="tight", pad_inches=0)
+        plt.close(fig)
 
-image = Image.open("truck.jpg")
-image = np.array(image.convert("RGB"))
-plt.figure(figsize=(10, 10))
-plt.imshow(image)
-plt.axis("on")
-plt.savefig("foo.png")
+    return {
+        "masks": np.array(processed_masks),
+        "scores": scores,
+        "logits": logits,
+    }
 
-from sam2.build_sam import build_sam2
-from sam2.sam2_image_predictor import SAM2ImagePredictor
 
+def main(image_path: str, save_path: Optional[str] = None) -> Dict[str, np.ndarray]:
+    """
+    Main execution function.
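+
+    A minimal usage sketch (mirroring the call at the bottom of this file; the
+    image path and output directory are placeholders):
+
+        results = main("truck.jpg", save_path="output")
+        best = results["masks"][0]  # masks are sorted by score, best first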
-sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" -model_cfg = "sam2_hiera_l.yaml" -device = torch.device("cuda") -sam2_model = build_sam2( - model_cfg, - sam2_checkpoint, - device=device, -) + Args: + image_path (str): Path to input image + save_path (str, optional): Directory to save visualizations + + Returns: + Dict[str, np.ndarray]: Processing results + """ + + # Load image + image = np.array(Image.open(image_path).convert("RGB")) + + # Initialize SAM2 model + sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt" + model_cfg = "sam2_hiera_l.yaml" + device = torch.device("cuda") + sam2_model = build_sam2( + model_cfg, + sam2_checkpoint, + device=device, + ) + + # Create predictor and process image + predictor = SAM2ImagePredictor(sam2_model) + + predictor.set_image(image) + + # Set input prompt + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + # Time mask prediction + start = time.perf_counter() + masks, scores, logits = predictor.predict( + point_coords=input_point, + point_labels=input_label, + multimask_output=True, + ) + + # Synchronize CUDA operations + tp.default_stream().synchronize() + torch.cuda.synchronize() + prediction_time = (time.perf_counter() - start) * 1000 + print(f"Prediction took {prediction_time:.2f}ms") + + # Sort masks by confidence score + sorted_ind = np.argsort(scores)[::-1] + masks = masks[sorted_ind] + scores = scores[sorted_ind] + logits = logits[sorted_ind] + + # Process and display results + results = process_predictions( + image, + masks, + scores, + logits, + point_coords=input_point, + input_labels=input_label, + save_path=save_path, + ) + + return results -import time -predictor = SAM2ImagePredictor(sam2_model) -start = time.perf_counter() -predictor.set_image(image) -end = time.perf_counter() -print(f"generate image embedding took {(end - start)*1000}") -input_point = np.array([[500, 375]]) -input_label = np.array([1]) - -start = time.perf_counter() -masks, scores, logits = predictor.predict( - point_coords=input_point, - point_labels=input_label, - multimask_output=True, -) -tp.default_stream().synchronize() -torch.cuda.synchronize() -end = time.perf_counter() -print(f"exec took {(end - start)*1000}") - -sorted_ind = np.argsort(scores)[::-1] -masks = masks[sorted_ind] -scores = scores[sorted_ind] -logits = logits[sorted_ind] - -show_masks( - image, - masks, - scores, - point_coords=input_point, - input_labels=input_label, - borders=True, -) +if __name__ == "__main__": + main("truck.jpg", save_path="output") diff --git a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py index 805eff170..07326f5c3 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py +++ b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py @@ -286,7 +286,6 @@ def build_sam2( for comp_name, comp_info in components.items(): if not comp_info["enabled"] or comp_name not in required_components_for_image: - print(comp_name) continue executable_file = os.path.join(saved_engines_path, comp_name) diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py index be6c98456..5481a77e7 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py @@ -36,11 +36,7 @@ def __call__(self, x): class MaskDownSampler(tp.Module): """ - Progressively downsample a mask by 
total_stride, each time by stride. - Note that LayerNorm is applied per *token*, like in ViT. - - With each downsample (by a factor stride**2), channel capacity increases by the same factor. - In the end, we linearly project to embed_dim channels. + Progressively downsample a mask by total_stride. """ def __init__( @@ -85,15 +81,8 @@ def forward(self, x): class CXBlock(tp.Module): - r"""ConvNeXt Block. There are two equivalent implementations: - (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) - (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back - We use (2) as we find it slightly faster in PyTorch - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + r"""ConvNeXt Block. + DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back """ def __init__( @@ -153,7 +142,6 @@ def __call__(self, x): return self.forward(x) def forward(self, x): - # normally x: (N, C, H, W) x = self.proj(x) for layer in self.layers: x = layer(x) @@ -193,8 +181,6 @@ def forward( masks: tp.Tensor, skip_mask_sigmoid: bool = False, ) -> Tuple[tp.Tensor, tp.Tensor]: - ## Process masks - # sigmoid, so that less domain shift from gt masks which are bool if not skip_mask_sigmoid: masks = tp.sigmoid(masks) masks = self.mask_downsampler(masks) diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py index 8c6cac6f4..7c0dfb15b 100755 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py @@ -309,7 +309,6 @@ def __init__( self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) self.freqs_cis = freqs_cis - print(f"rope_k_repeat : {rope_k_repeat}") self.rope_k_repeat = rope_k_repeat def __call__(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: tp.Tensor) -> Tensor: diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py index a57a45ade..e2995aafb 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py @@ -22,7 +22,6 @@ import torch.distributed import torch.nn.functional as F import tripy as tp -import nvtx from torch.nn.init import trunc_normal_ @@ -36,18 +35,6 @@ NO_OBJ_SCORE = -1024.0 -markers = {} -events = {} - - -def profile_start(name, color="blue"): - markers[name] = nvtx.start_range(message=name, color=color) - - -def profile_stop(name): - nvtx.end_range(markers[name]) - - class SAM2Base(torch.nn.Module): def __init__( self, @@ -845,7 +832,6 @@ def track_step( mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - profile_start("forward_sam_heads") sam_outputs = self._forward_sam_heads( backbone_features=pix_feat_with_mem, point_inputs=point_inputs, @@ -853,8 +839,6 @@ def track_step( high_res_features=high_res_features, multimask_output=multimask_output, ) - tp.default_stream().synchronize() - profile_stop("forward_sam_heads") ( _, diff --git 
a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py index d17651b1b..181b4edf8 100755 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py @@ -94,7 +94,7 @@ def scaled_dot_product_attention( - Paper: https://arxiv.org/abs/1706.03762v7 """ - if is_causal: # this path is not called in demoDiffusion + if is_causal: target_shape = query.shape[-2:-1] + key.shape[-2:-1] # TODO: #228: WAR to prevent computing output rank in infer_rank for reshape target_shape.trace_tensor.shape = (2,) diff --git a/tripy/examples/segment-anything-model-v2/test_backbone.py b/tripy/examples/segment-anything-model-v2/test_backbone.py deleted file mode 100644 index 3ac0dfd1f..000000000 --- a/tripy/examples/segment-anything-model-v2/test_backbone.py +++ /dev/null @@ -1,143 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time -import tripy as tp -import torch - -from sam2.modeling.backbones import image_encoder, hieradet -from sam2.modeling.position_encoding import PositionEmbeddingSine - -from tripy.logging import logger - -logger.verbosity = "mlir" - - -############## trunk -- Hiera ############# -def test_trunk(): - trunk = hieradet.Hiera( - embed_dim=144, - num_heads=2, - stages=[2, 6, 36, 4], - global_att_blocks=[23, 33, 43], - window_pos_embed_bkg_spatial_size=[7, 7], - window_spec=[8, 4, 16, 8], - ) - trunk.generate_static_pos_embed() - # trunk_inp = tp.ones((1, 3, 1024, 1024)) - # trunk_out = trunk(trunk_inp) - # print(trunk_out[1]) - - print("Start compiling trunk...") - start = time.time() - compiled_tp_trunk = tp.compile( - trunk, - optimization_level=3, - args=[ - tp.InputInfo((1, 3, 1024, 1024), dtype=tp.float32), - ], - ) - print(f"Compile trunk took {time.time() - start}s") - - -############## neck -- FpnNeck ############# -def test_neck(): - position_encoding = PositionEmbeddingSine( - num_pos_feats=256, - normalize=True, - scale=None, - temperature=10000, - ) - neck = image_encoder.FpnNeck( - position_encoding=position_encoding, - d_model=256, - backbone_channel_list=[1152, 576, 288, 144], - fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ) - - # neck_inp = [ - # tp.ones([1, 144, 256, 256]), - # tp.ones([1, 288, 128, 128]), - # tp.ones([1, 576, 64, 64]), - # tp.ones([1, 1152, 32, 32]), - # ] - # neck_out = neck(*neck_inp) - # print(neck_out[3]) - - print("Start compiling FpnNeck...") - start = time.time() - compiled_tp_neck = tp.compile( - neck, - optimization_level=3, - args=[ - tp.InputInfo((1, 144, 256, 256), dtype=tp.float32), - tp.InputInfo((1, 288, 128, 128), dtype=tp.float32), - tp.InputInfo((1, 576, 64, 64), dtype=tp.float32), - tp.InputInfo((1, 1152, 32, 32), dtype=tp.float32), - ], - ) - print(f"Compile image encoder took {time.time() - start}s") - - 
-############ image_encoder (trunk + neck) ################## -def test_image_encoder(): - trunk = hieradet.Hiera( - embed_dim=144, - num_heads=2, - stages=[2, 6, 36, 4], - global_att_blocks=[23, 33, 43], - window_pos_embed_bkg_spatial_size=[7, 7], - window_spec=[8, 4, 16, 8], - ) - trunk.generate_static_pos_embed((256, 256)) - - position_encoding = PositionEmbeddingSine( - num_pos_feats=256, - normalize=True, - scale=None, - temperature=10000, - ) - neck = image_encoder.FpnNeck( - position_encoding=position_encoding, - d_model=256, - backbone_channel_list=[1152, 576, 288, 144], - fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ) - - encoder = image_encoder.ImageEncoder( - trunk=trunk, - neck=neck, - scalp=1, - ) - - # test eager mode - # inp = tp.ones((1, 3, 1024, 1024)) - # out = encoder(inp) - # print(out) - - # test compilation - print("Start compiling image encoder...") - start = time.time() - compiled_tp_image_encoder = tp.compile( - encoder.forward, - args=[ - tp.InputInfo((1, 3, 1024, 1024), dtype=tp.float32), - ], - ) - print(f"Compile image encoder took {time.time() - start}s") - - -test_image_encoder()