diff --git a/INSTALL.md b/INSTALL.md index be3b74c2..1582983d 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -11,6 +11,28 @@ Then, install SAM 2 from the root of this repository via pip install -e ".[demo]" ``` +Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows: +```bash +# skip the SAM 2 CUDA extension +SAM2_BUILD_CUDA=0 pip install -e ".[demo]" +``` +This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases. + +### Building the SAM 2 CUDA extension + +By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.) + +If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, you can still use SAM 2 for both image and video applications, but the post-processing step (removing small holes and sprinkles in the output masks) will be skipped. This shouldn't affect the results in most cases. + +If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows +```bash +pip uninstall -y SAM-2; SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]" +``` + +Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`. + +Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime. + ### Common Installation Issues Click each issue for its solutions: @@ -22,6 +44,8 @@ I got `ImportError: cannot import name '_C' from 'sam2'`
This is usually because you haven't run the `pip install -e ".[demo]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails. + +In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/segment-anything-2/issues/77.
diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py index 2a0b7960..b5b6fa2f 100644 --- a/sam2/modeling/sam/transformer.py +++ b/sam2/modeling/sam/transformer.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import contextlib import math import warnings from functools import partial @@ -14,12 +15,30 @@ from torch import nn, Tensor from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis - from sam2.modeling.sam2_utils import MLP from sam2.utils.misc import get_sdpa_settings warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) class TwoWayTransformer(nn.Module): @@ -246,12 +265,19 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: dropout_p = self.dropout_p if self.training else 0.0 # Attention - with torch.backends.cuda.sdp_kernel( - enable_flash=USE_FLASH_ATTN, - # if Flash attention kernel is off, then math kernel needs to be enabled - enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, - enable_mem_efficient=OLD_GPU, - ): + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) @@ -313,12 +339,19 @@ def forward( dropout_p = self.dropout_p if self.training else 0.0 # Attention - with torch.backends.cuda.sdp_kernel( - enable_flash=USE_FLASH_ATTN, - # if Flash attention kernel is off, then math kernel needs to be enabled - enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, - enable_mem_efficient=OLD_GPU, - ): + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index bf6a1799..e4f5d4ee 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area): # Holes are those connected components in background with area <= self.max_area # (background regions are those with mask scores <= 0) assert max_area > 0, "max_area must be positive" - labels, areas = get_connected_components(mask <= 0) - is_hole = (labels > 0) & (areas <= max_area) - # We fill holes with a small positive mask score (0.1) to change them to foreground. - mask = torch.where(is_hole, 0.1, mask) + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0, allow_cpu=False) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. " + "Consider building SAM 2 with CUDA extension to enable post-processing (see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + return mask diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py index d05cd3e5..995baf98 100644 --- a/sam2/utils/transforms.py +++ b/sam2/utils/transforms.py @@ -4,6 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import warnings + import torch import torch.nn as nn import torch.nn.functional as F @@ -78,22 +80,38 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: from sam2.utils.misc import get_connected_components masks = masks.float() - if self.max_hole_area > 0: - # Holes are those connected components in background with area <= self.fill_hole_area - # (background regions are those with mask scores <= self.mask_threshold) - mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image - labels, areas = get_connected_components(mask_flat <= self.mask_threshold) - is_hole = (labels > 0) & (areas <= self.max_hole_area) - is_hole = is_hole.reshape_as(masks) - # We fill holes with a small positive mask score (10.0) to change them to foreground. - masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) - if self.max_sprinkle_area > 0: - labels, areas = get_connected_components(mask_flat > self.mask_threshold) - is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) - is_hole = is_hole.reshape_as(masks) - # We fill holes with negative mask score (-10.0) to change them to background. - masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. " + "Consider building SAM 2 with CUDA extension to enable post-processing (see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) return masks diff --git a/setup.py b/setup.py index 85ae842f..e8591a49 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import os from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension @@ -36,8 +37,18 @@ "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"], } +# By default, we also build the SAM 2 CUDA extension. +# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. +BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" +# By default, we allow SAM 2 installation to proceed even with build errors. +# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. +BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" + def get_extensions(): + if not BUILD_CUDA: + return [] + srcs = ["sam2/csrc/connected_components.cu"] compile_args = { "cxx": [], @@ -52,6 +63,40 @@ def get_extensions(): return ext_modules +class BuildExtensionIgnoreErrors(BuildExtension): + # Catch and skip errors during extension building and print a warning message + # (note that this message only shows up under verbose build mode + # "pip install -v -e ." or "python setup.py build_ext -v") + ERROR_MSG = ( + "{}\n\n" + "Failed to build the SAM 2 CUDA extension due to the error above. " + "You can still use SAM 2, but some post-processing functionality may be limited " + "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n" + ) + + def finalize_options(self): + try: + super().finalize_options() + except Exception as e: + print(self.ERROR_MSG.format(e)) + self.extensions = [] + + def build_extensions(self): + try: + super().build_extensions() + except Exception as e: + print(self.ERROR_MSG.format(e)) + self.extensions = [] + + def get_ext_filename(self, ext_name): + try: + return super().get_ext_filename(ext_name) + except Exception as e: + print(self.ERROR_MSG.format(e)) + self.extensions = [] + return "_C.so" + + # Setup configuration setup( name=NAME, @@ -68,5 +113,11 @@ def get_extensions(): extras_require=EXTRA_PACKAGES, python_requires=">=3.10.0", ext_modules=get_extensions(), - cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)}, + cmdclass={ + "build_ext": ( + BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) + if BUILD_ALLOW_ERRORS + else BuildExtension.with_options(no_python_abi_suffix=True) + ), + }, )