Skip to content

Commit

Permalink
Make it optional to build CUDA extension for SAM 2; also fallback to …
Browse files Browse the repository at this point in the history
…math kernel if Flash Attention fails
  • Loading branch information
ronghanghu committed Aug 6, 2024
1 parent 0230c5f commit 509f0b1
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 26 deletions.
21 changes: 21 additions & 0 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@ Then, install SAM 2 from the root of this repository via
pip install -e ".[demo]"
```

If your environment doesn't support CUDA, you may turn off the CUDA extension building during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
```bash
# skip CUDA extension
SAM2_BUILD_CUDA=0 pip install -e ".[demo]"
```

### Building SAM 2 CUDA extensions

If you see a message like `Failed to build SAM 2 CUDA extensions due to the error above` during installation or `Skipping the post-processing step due to the error above` at runtime, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, you can still use SAM 2 for both image and video applications, but the post-processing step (removing small holes and sprinkles in the output masks) will be skipped. This shouldn't affect the results in most cases.

If you would like to enable this post-processing step, you can reinstall SAM 2 with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building CUDA extension, as follows
```bash
pip uninstall SAM-2; SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]"
```

Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.

Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.

### Common Installation Issues

Click each issue for its solutions:
Expand All @@ -22,6 +41,8 @@ I got `ImportError: cannot import name '_C' from 'sam2'`
<br/>

This is usually because you haven't run the `pip install -e ".[demo]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.

In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/segment-anything-2/issues/77.
</details>

<details>
Expand Down
25 changes: 19 additions & 6 deletions sam2/modeling/sam/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,25 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:

dropout_p = self.dropout_p if self.training else 0.0
# Attention
with torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
):
try:
with torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed to load: {e}\n"
"Falling back to math kernels for scaled_dot_product_attention (which may have a slower speed) ",
category=UserWarning,
stacklevel=2,
)
global MATH_KERNEL_ON, OLD_GPU
MATH_KERNEL_ON = True
OLD_GPU = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

out = self._recombine_heads(out)
Expand Down
22 changes: 18 additions & 4 deletions sam2/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area):
# Holes are those connected components in background with area <= self.max_area
# (background regions are those with mask scores <= 0)
assert max_area > 0, "max_area must be positive"
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)

input_mask = mask
try:
labels, areas = get_connected_components(mask <= 0, allow_cpu=False)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
mask = input_mask

return mask


Expand Down
48 changes: 33 additions & 15 deletions sam2/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -78,22 +80,38 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
from sam2.utils.misc import get_connected_components

masks = masks.float()
if self.max_hole_area > 0:
# Holes are those connected components in background with area <= self.fill_hole_area
# (background regions are those with mask scores <= self.mask_threshold)
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with a small positive mask score (10.0) to change them to foreground.
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
input_masks = masks
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
try:
if self.max_hole_area > 0:
# Holes are those connected components in background with area <= self.fill_hole_area
# (background regions are those with mask scores <= self.mask_threshold)
labels, areas = get_connected_components(
mask_flat <= self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with a small positive mask score (10.0) to change them to foreground.
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)

if self.max_sprinkle_area > 0:
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with negative mask score (-10.0) to change them to background.
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
if self.max_sprinkle_area > 0:
labels, areas = get_connected_components(
mask_flat > self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with negative mask score (-10.0) to change them to background.
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
masks = input_masks

masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks
50 changes: 49 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
Expand Down Expand Up @@ -36,8 +37,15 @@
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}

# allow turning off CUDA build with `export SAM2_BUILD_CUDA=0`
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"


def get_extensions():
if not BUILD_CUDA:
return []

srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
Expand All @@ -52,6 +60,40 @@ def get_extensions():
return ext_modules


class BuildExtensionIgnoreErrors(BuildExtension):
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
ERROR_MSG = (
"{}\n\n"
"Failed to build SAM 2 CUDA extensions due to the error above. "
"You can still use SAM 2, but some functionality may be limited (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)

def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(self.ERROR_MSG.format(e))
self.extensions = []

def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(self.ERROR_MSG.format(e))
self.extensions = []

def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(self.ERROR_MSG.format(e))
self.extensions = []
return "_C.so"


# Setup configuration
setup(
name=NAME,
Expand All @@ -68,5 +110,11 @@ def get_extensions():
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
cmdclass={
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
),
},
)

0 comments on commit 509f0b1

Please sign in to comment.