Skip to content

Commit

Permalink
also catch errors during installation in case CUDAExtension cannot …
Browse files Browse the repository at this point in the history
…be loaded (#175)

Previously we only catch build errors in `BuildExtension` in https://github.com/facebookresearch/segment-anything-2/pull/155. However, in some cases, the `CUDAExtension` instance might not load. So in this PR, we also catch such errors for `CUDAExtension`.
  • Loading branch information
ronghanghu authored Aug 7, 2024
1 parent 6ecb5ff commit 6186d15
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,55 +44,64 @@
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"

# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2, but some post-processing functionality may be limited "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)


def get_extensions():
if not BUILD_CUDA:
return []

srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
try:
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
except Exception as e:
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
ext_modules = []
else:
raise e

return ext_modules


class BuildExtensionIgnoreErrors(BuildExtension):
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2, but some post-processing functionality may be limited "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)

def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(self.ERROR_MSG.format(e))
print(CUDA_ERROR_MSG.format(e))
self.extensions = []

def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(self.ERROR_MSG.format(e))
print(CUDA_ERROR_MSG.format(e))
self.extensions = []

def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(self.ERROR_MSG.format(e))
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"

Expand Down

0 comments on commit 6186d15

Please sign in to comment.