Skip to content

Commit

Permalink
Make it optional to build CUDA extension for SAM 2; also fallback to …
Browse files Browse the repository at this point in the history
…math kernel if Flash Attention fails
  • Loading branch information
ronghanghu committed Aug 6, 2024
1 parent 0230c5f commit 268ad1c
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 33 deletions.
23 changes: 23 additions & 0 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,27 @@ Then, install SAM 2 from the root of this repository via
pip install -e ".[demo]"
```

If your environment doesn't support CUDA, you may turn off the CUDA extension building during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
```bash
# skip CUDA extension
SAM2_BUILD_CUDA=0 pip install -e ".[demo]"
```

### Building SAM 2 CUDA extensions

By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build.

If you see a message like `Failed to build SAM 2 CUDA extensions due to the error above` during installation or `Skipping the post-processing step due to the error above` at runtime, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, you can still use SAM 2 for both image and video applications, but the post-processing step (removing small holes and sprinkles in the output masks) will be skipped. This shouldn't affect the results in most cases.

If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building CUDA extension, as follows
```bash
pip uninstall SAM-2; SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]"
```

Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.

Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.

### Common Installation Issues

Click each issue for its solutions:
Expand All @@ -22,6 +43,8 @@ I got `ImportError: cannot import name '_C' from 'sam2'`
<br/>

This is usually because you haven't run the `pip install -e ".[demo]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.

In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/segment-anything-2/issues/77.
</details>

<details>
Expand Down
59 changes: 46 additions & 13 deletions sam2/modeling/sam/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import warnings
from functools import partial
Expand All @@ -14,12 +15,30 @@
from torch import nn, Tensor

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

from sam2.modeling.sam2_utils import MLP
from sam2.utils.misc import get_sdpa_settings

warnings.simplefilter(action="ignore", category=FutureWarning)
# Check whether Flash Attention is available (and use it by default)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False


def sdp_kernel_context(dropout_p):
"""
Get the context for the attention scaled dot-product kernel. We use Flash Attention
by default, but fall back to all available kernels if Flash Attention fails.
"""
if ALLOW_ALL_KERNELS:
return contextlib.nullcontext()

return torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
)


class TwoWayTransformer(nn.Module):
Expand Down Expand Up @@ -246,12 +265,19 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:

dropout_p = self.dropout_p if self.training else 0.0
# Attention
with torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
):
try:
with sdp_kernel_context(dropout_p):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
category=UserWarning,
stacklevel=2,
)
global ALLOW_ALL_KERNELS
ALLOW_ALL_KERNELS = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

out = self._recombine_heads(out)
Expand Down Expand Up @@ -313,12 +339,19 @@ def forward(

dropout_p = self.dropout_p if self.training else 0.0
# Attention
with torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
):
try:
with sdp_kernel_context(dropout_p):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
category=UserWarning,
stacklevel=2,
)
global ALLOW_ALL_KERNELS
ALLOW_ALL_KERNELS = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

out = self._recombine_heads(out)
Expand Down
22 changes: 18 additions & 4 deletions sam2/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area):
# Holes are those connected components in background with area <= self.max_area
# (background regions are those with mask scores <= 0)
assert max_area > 0, "max_area must be positive"
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)

input_mask = mask
try:
labels, areas = get_connected_components(mask <= 0, allow_cpu=False)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
mask = input_mask

return mask


Expand Down
48 changes: 33 additions & 15 deletions sam2/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -78,22 +80,38 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
from sam2.utils.misc import get_connected_components

masks = masks.float()
if self.max_hole_area > 0:
# Holes are those connected components in background with area <= self.fill_hole_area
# (background regions are those with mask scores <= self.mask_threshold)
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with a small positive mask score (10.0) to change them to foreground.
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
input_masks = masks
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
try:
if self.max_hole_area > 0:
# Holes are those connected components in background with area <= self.fill_hole_area
# (background regions are those with mask scores <= self.mask_threshold)
labels, areas = get_connected_components(
mask_flat <= self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with a small positive mask score (10.0) to change them to foreground.
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)

if self.max_sprinkle_area > 0:
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with negative mask score (-10.0) to change them to background.
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
if self.max_sprinkle_area > 0:
labels, areas = get_connected_components(
mask_flat > self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with negative mask score (-10.0) to change them to background.
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
masks = input_masks

masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks
53 changes: 52 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
Expand Down Expand Up @@ -36,8 +37,18 @@
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}

# By default, we alos build SAM 2 CUDA extensions.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"


def get_extensions():
if not BUILD_CUDA:
return []

srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
Expand All @@ -52,6 +63,40 @@ def get_extensions():
return ext_modules


class BuildExtensionIgnoreErrors(BuildExtension):
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
ERROR_MSG = (
"{}\n\n"
"Failed to build SAM 2 CUDA extensions due to the error above. "
"You can still use SAM 2, but some functionality may be limited (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)

def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(self.ERROR_MSG.format(e))
self.extensions = []

def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(self.ERROR_MSG.format(e))
self.extensions = []

def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(self.ERROR_MSG.format(e))
self.extensions = []
return "_C.so"


# Setup configuration
setup(
name=NAME,
Expand All @@ -68,5 +113,11 @@ def get_extensions():
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
cmdclass={
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
),
},
)

0 comments on commit 268ad1c

Please sign in to comment.