Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails #155

Merged

Conversation

ronghanghu
Copy link
Contributor

@ronghanghu ronghanghu commented Aug 6, 2024

In this PR, we make it optional to build the SAM 2 CUDA extension, in observation that many users encounter difficulties with the CUDA compilation step.

  1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with SAM2_BUILD_CUDA=0.
  2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step.

We also fall back to the all available kernels if the Flash Attention kernel fails.

@ronghanghu ronghanghu marked this pull request as draft August 6, 2024 05:42
@ronghanghu ronghanghu force-pushed the ronghanghu/cuda_kernel_optional branch 3 times, most recently from 509f0b1 to 268ad1c Compare August 6, 2024 15:08
@ronghanghu ronghanghu changed the title Make it optional to build CUDA extension for SAM 2; also fallback to math kernel if Flash Attention fails Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails Aug 6, 2024
@ronghanghu ronghanghu force-pushed the ronghanghu/cuda_kernel_optional branch 3 times, most recently from 8522a19 to 6943cf6 Compare August 6, 2024 17:05
@bhack
Copy link

bhack commented Aug 6, 2024

Do you think that Kornia like pure pytorch connected components will be too much numerically misaligned?

kornia/kornia#1184

@ronghanghu ronghanghu marked this pull request as ready for review August 6, 2024 17:36
@ronghanghu
Copy link
Contributor Author

Do you think that Kornia like pure pytorch connected components will be too much numerically misaligned?

kornia/kornia#1184

@bhack Thanks for the suggestion! We have also tried this kornia implementation before, but it was too slow for video applications (as it's using an iteration loop in Python and its algorithm has not been carefully optimized for GPUs), so we added a custom CUDA kernel in connected_components.cu instead, which is much faster.

…all available kernels if Flash Attention fails

In this PR, we make it optional to build the SAM 2 CUDA extension, in observation that many users encounter difficulties with the CUDA compilation step.
1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`.
2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step.

We also fall back to the all available kernels if the Flash Attention kernel fails.
@ronghanghu ronghanghu force-pushed the ronghanghu/cuda_kernel_optional branch from 6943cf6 to 1757177 Compare August 6, 2024 17:45
@bhack
Copy link

bhack commented Aug 6, 2024

Yes I know that it has loops. It is not easy to implement with pytorch ops. Have you benchmarked how is the pytorch compiler behaving with these loops?

@bhack
Copy link

bhack commented Aug 6, 2024

Quite funny that...
pytorch/pytorch#113538 (comment)

@ronghanghu
Copy link
Contributor Author

Yes I know that it has loops. It is not easy to implement with pytorch ops. Have you benchmarked how is the pytorch compiler behaving with these loops?

@bhack In our internal benchmarking, the custom CUDA kernel is much (~100x) faster than the kornia implementation even if we try to optimize the latter (e.g. via torch compilation). Another user also reported similar observations (prittt/YACCLAB#28 (comment)).

ronghanghu referenced this pull request Aug 7, 2024
…be loaded (#175)

Previously we only catch build errors in `BuildExtension` in https://github.com/facebookresearch/segment-anything-2/pull/155. However, in some cases, the `CUDAExtension` instance might not load. So in this PR, we also catch such errors for `CUDAExtension`.
fbcotter added a commit to wayveai/segment-anything-2 that referenced this pull request Sep 6, 2024
… into facebookresearch-main

* 'main' of github.com:facebookresearch/segment-anything-2: (40 commits)
  open `README.md` with unicode (to support Hugging Face emoji); fix various typos (facebookresearch#218)
  accept kwargs in auto_mask_generator
  Fix HF image predictor
  improving warning message and adding further tips for installation (facebookresearch#204)
  better support for non-CUDA devices (CPU, MPS) (facebookresearch#192)
  Update hieradet.py
  add Colab support to the notebooks; pack config files in `sam2_configs` package during installation (facebookresearch#176)
  also catch errors during installation in case `CUDAExtension` cannot be loaded (facebookresearch#175)
  Add interface for box prompt in SAM 2 video predictor (facebookresearch#174)
  Address comment
  Update hieradet.py
  Update docstrings
  Revert code snippet
  Updated INSTALL.md with CUDA_HOME-related troubleshooting (facebookresearch#140)
  Format using ufmt
  Update INSTALL.md (facebookresearch#156)
  Update README
  Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (facebookresearch#155)
  Clean up
  Address comment
  ...
xydy666 pushed a commit to xydy666/segment-anything-2 that referenced this pull request Sep 17, 2024
…all available kernels if Flash Attention fails (facebookresearch#155)

In this PR, we make it optional to build the SAM 2 CUDA extension, in observation that many users encounter difficulties with the CUDA compilation step.
1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`.
2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step.

We also fall back to the all available kernels if the Flash Attention kernel fails.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants