More improvements with sample usability
parthchadha committed Nov 12, 2024
1 parent 9c9c871 commit cdc57d1
Showing 8 changed files with 177 additions and 261 deletions.
27 changes: 27 additions & 0 deletions tripy/examples/segment-anything-model-v2/README.md
@@ -0,0 +1,27 @@
# SAM2: Segment Anything in Images and Videos

## Introduction

This is an implementation of the SAM2 model ([original repository](https://github.com/facebookresearch/sam2/tree/main) by Meta).

## Running The Example

### Image pipeline

1. Install prerequisites:

```bash
apt-get install ffmpeg libsm6 libxext6 -y
python3 -m pip install -r requirements.txt
mkdir checkpoints && cd checkpoints && wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt && cd ..
```
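
The demo expects the checkpoint at `./checkpoints/sam2.1_hiera_large.pt`, which is where the commands above place it.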

2. Run the example:

```bash
python3 image_demo.py
```
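
By default the demo segments `truck.jpg` with a single point prompt and saves one visualization per predicted mask. Assuming the defaults in `image_demo.py` (`save_path="output"`), you can inspect the results with:

```bash
ls output/   # one PNG per predicted mask, named mask_<index>_score_<score>.png
```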

### Video segmentation pipeline

TBD
228 changes: 146 additions & 82 deletions tripy/examples/segment-anything-model-v2/image_demo.py
@@ -12,39 +12,58 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
import os
import time
from typing import Tuple, Optional, Dict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

import tripy as tp
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

plt.switch_backend("agg")  # Switch to non-interactive backend so figures can be saved headlessly

def process_and_show_mask(
mask: np.ndarray, ax: plt.Axes, random_color: bool = False, borders: bool = True
) -> np.ndarray:
"""
Process and display a segmentation mask, returning the processed mask for testing.
"""
# Generate mask color
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

# Process mask
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)

ax.imshow(mask_image)
return mask_image


def show_points(
coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375
) -> Tuple[np.ndarray, np.ndarray]:
"""
Display point prompts and return point coordinates for testing.
"""
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]

ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
@@ -64,92 +83,137 @@ def show_points(coords, labels, ax, marker_size=375):
linewidth=1.25,
)

return pos_points, neg_points


def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]:
"""
Display a bounding box and return its coordinates for testing.
"""
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
return x0, y0, w, h


def process_predictions(
image: np.ndarray,
masks: np.ndarray,
scores: np.ndarray,
logits: np.ndarray,
point_coords: Optional[np.ndarray] = None,
box_coords: Optional[np.ndarray] = None,
input_labels: Optional[np.ndarray] = None,
save_path: Optional[str] = None,
) -> Dict[str, np.ndarray]:
"""
Process and visualize predictions, returning a dictionary containing processed masks, scores, and logits.
"""
processed_masks = []

# Create output directory if it doesn't exist
if save_path:
os.makedirs(save_path, exist_ok=True)


for i, (mask, score) in enumerate(zip(masks, scores)):
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)

processed_mask = process_and_show_mask(mask, ax)
processed_masks.append(processed_mask)

if point_coords is not None:
assert input_labels is not None, "Input labels required for point prompts"
show_points(point_coords, input_labels, ax)

if box_coords is not None:
show_box(box_coords, ax)

if len(scores) > 1:
ax.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)

ax.axis("off")

if save_path:
plt.savefig(os.path.join(save_path, f"mask_{i}_score_{score:.3f}.png"), bbox_inches="tight", pad_inches=0)
plt.close(fig)

return {
"masks": np.array(processed_masks),
"scores": scores,
"logits": logits,
}


def main(image_path: str, save_path: Optional[str] = None):
"""
Main execution function.
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = torch.device("cuda")
sam2_model = build_sam2(
model_cfg,
sam2_checkpoint,
device=device,
)
Args:
image_path (str): Path to input image
save_path (str, optional): Directory to save visualizations
Returns:
Dict[str, np.ndarray]: Processing results
"""

# Load image
image = np.array(Image.open(image_path).convert("RGB"))

# Initialize SAM2 model
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = torch.device("cuda")
sam2_model = build_sam2(
model_cfg,
sam2_checkpoint,
device=device,
)

# Create predictor and process image
predictor = SAM2ImagePredictor(sam2_model)

predictor.set_image(image)

# Set input prompt
input_point = np.array([[500, 375]])
input_label = np.array([1])

# Time mask prediction
start = time.perf_counter()
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)

# Synchronize CUDA operations
tp.default_stream().synchronize()
torch.cuda.synchronize()
prediction_time = (time.perf_counter() - start) * 1000
print(f"Prediction took {prediction_time:.2f}ms")

# Sort masks by confidence score
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]

# Process and display results
results = process_predictions(
image,
masks,
scores,
logits,
point_coords=input_point,
input_labels=input_label,
save_path=save_path,
)

return results

if __name__ == "__main__":
main("truck.jpg", save_path="output")
1 change: 0 additions & 1 deletion tripy/examples/segment-anything-model-v2/sam2/build_sam.py
@@ -286,7 +286,6 @@ def build_sam2(

for comp_name, comp_info in components.items():
if not comp_info["enabled"] or comp_name not in required_components_for_image:
continue

executable_file = os.path.join(saved_engines_path, comp_name)
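
For context, the loop above is the engine-cache lookup in `build_sam2`: each enabled component required by the image pipeline resolves to a saved executable under `saved_engines_path`. A standalone sketch of the same skip-and-resolve pattern, with illustrative stand-ins for the real `components` dict and paths:

```python
import os

# Illustrative stand-ins; the real values live inside build_sam2.
components = {
    "image_encoder": {"enabled": True},
    "prompt_encoder": {"enabled": True},
    "memory_encoder": {"enabled": False},
}
required_components_for_image = {"image_encoder", "prompt_encoder"}
saved_engines_path = "./saved_engines"

for comp_name, comp_info in components.items():
    # Skip disabled components and ones the image pipeline never uses.
    if not comp_info["enabled"] or comp_name not in required_components_for_image:
        continue
    executable_file = os.path.join(saved_engines_path, comp_name)
    print(f"would load cached engine: {executable_file}")
```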
@@ -36,11 +36,7 @@ def __call__(self, x):

class MaskDownSampler(tp.Module):
"""
Progressively downsample a mask by total_stride.
"""

def __init__(
@@ -85,15 +81,8 @@ def forward(self, x):


class CXBlock(tp.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
r"""ConvNeXt Block.
DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
"""

def __init__(
@@ -153,7 +142,6 @@ def __call__(self, x):
return self.forward(x)

def forward(self, x):
x = self.proj(x)
for layer in self.layers:
x = layer(x)
@@ -193,8 +181,6 @@ def forward(
masks: tp.Tensor,
skip_mask_sigmoid: bool = False,
) -> Tuple[tp.Tensor, tp.Tensor]:
if not skip_mask_sigmoid:
masks = tp.sigmoid(masks)
masks = self.mask_downsampler(masks)
@@ -309,7 +309,6 @@ def __init__(
self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
self.freqs_cis = freqs_cis
print(f"rope_k_repeat : {rope_k_repeat}")
self.rope_k_repeat = rope_k_repeat

def __call__(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: tp.Tensor) -> Tensor:
