diff --git a/.github/workflows/cuda/cu102-Linux.sh b/.github/workflows/cuda/cu102-Linux.sh
index 46fb05319..ada39d54e 100644
--- a/.github/workflows/cuda/cu102-Linux.sh
+++ b/.github/workflows/cuda/cu102-Linux.sh
@@ -1,6 +1,8 @@
 #!/bin/bash
 
-OS=ubuntu1804
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
 
 wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
 sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
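For reference, the derived slug matches NVIDIA's repo directory naming. A quick sketch of what the substitution yields on a hypothetical 22.04 runner (the inner echo in the script is redundant; `lsb_release -sr | tr -d .` behaves the same):

    lsb_release -sr              # -> 22.04
    lsb_release -sr | tr -d .    # -> 2204
    # OS=ubuntu2204, i.e. .../compute/cuda/repos/ubuntu2204/x86_64/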
diff --git a/.github/workflows/cuda/cu113-Linux.sh b/.github/workflows/cuda/cu113-Linux.sh
index b89a7fb85..0b804d974 100644
--- a/.github/workflows/cuda/cu113-Linux.sh
+++ b/.github/workflows/cuda/cu113-Linux.sh
@@ -1,11 +1,17 @@
 #!/bin/bash
 
-OS=ubuntu1804
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
 
 wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
 sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
 wget -nv https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb
 sudo dpkg -i cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb
+
+# TODO: On versions < 22.04 the key can still be added via apt-key.
+# On newer versions apt-key is deprecated and the key should be moved into the trusted keyring folder instead:
+# sudo mv /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub /etc/apt/trusted.gpg.d/
 sudo apt-key add /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub
 
 sudo apt-get -qq update
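The TODO above concerns apt-key's deprecation on newer Ubuntu releases. A minimal sketch of the two approaches, both drawn from scripts in this changeset (11.3 and 11.6 keep apt-key; the new 11.7 and 12.0 scripts below switch to a keyring file):

    # deprecated, still used for the 11.3/11.6 local repos:
    sudo apt-key add /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub

    # keyring-file alternative, as the 11.7/12.0 scripts below do:
    sudo cp /var/cuda-repo-${OS}-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/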
diff --git a/.github/workflows/cuda/cu116-Linux.sh b/.github/workflows/cuda/cu116-Linux.sh
index e3e4e2af7..f6ebbe3be 100644
--- a/.github/workflows/cuda/cu116-Linux.sh
+++ b/.github/workflows/cuda/cu116-Linux.sh
@@ -1,10 +1,13 @@
 #!/bin/bash
 
-OS=ubuntu1804
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
 
 wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
 sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
 wget -nv https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb
+
 sudo dpkg -i cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb
 sudo apt-key add /var/cuda-repo-${OS}-11-6-local/7fa2af80.pub
 
diff --git a/.github/workflows/cuda/cu117-Linux-env.sh b/.github/workflows/cuda/cu117-Linux-env.sh
new file mode 100644
index 000000000..ab432d16f
--- /dev/null
+++ b/.github/workflows/cuda/cu117-Linux-env.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+CUDA_HOME=/usr/local/cuda-11.7
+LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+PATH=${CUDA_HOME}/bin:${PATH}
+
+export FORCE_CUDA=1
+export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+export CUDA_HOME=/usr/local/cuda-11.7
\ No newline at end of file
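Note that the bare CUDA_HOME/LD_LIBRARY_PATH/PATH assignments only take effect if this file is sourced into the calling shell (and LD_LIBRARY_PATH only reaches child processes if it was already exported there); the actual caller is not part of this diff. A hypothetical consuming step:

    # hypothetical: source the env file matching the matrix's CUDA version
    source .github/workflows/cuda/cu117-Linux-env.sh
    echo "CUDA_HOME=$CUDA_HOME"    # /usr/local/cuda-11.7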
diff --git a/.github/workflows/cuda/cu117-Linux.sh b/.github/workflows/cuda/cu117-Linux.sh
new file mode 100644
index 000000000..40e66f385
--- /dev/null
+++ b/.github/workflows/cuda/cu117-Linux.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
+
+wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
+sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
+wget -nv https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
+
+sudo dpkg -i cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
+sudo cp /var/cuda-repo-${OS}-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/
+
+sudo apt-get -qq update
+sudo apt install -y cuda cuda-nvcc-11-7 cuda-libraries-dev-11-7
+sudo apt clean
+
+rm -f cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
diff --git a/.github/workflows/cuda/cu120-Linux-env.sh b/.github/workflows/cuda/cu120-Linux-env.sh
new file mode 100644
index 000000000..37917cc82
--- /dev/null
+++ b/.github/workflows/cuda/cu120-Linux-env.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+CUDA_HOME=/usr/local/cuda-12.0
+LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+PATH=${CUDA_HOME}/bin:${PATH}
+
+export FORCE_CUDA=1
+export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+export CUDA_HOME=/usr/local/cuda-12.0
\ No newline at end of file
diff --git a/.github/workflows/cuda/cu120-Linux.sh b/.github/workflows/cuda/cu120-Linux.sh
new file mode 100644
index 000000000..56996dee2
--- /dev/null
+++ b/.github/workflows/cuda/cu120-Linux.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
+
+wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
+sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
+wget -nv https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
+
+sudo dpkg -i cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
+sudo cp /var/cuda-repo-${OS}-12-0-local/cuda-*-keyring.gpg /usr/share/keyrings/
+
+sudo apt-get -qq update
+sudo apt install -y cuda cuda-nvcc-12-0 cuda-libraries-dev-12-0
+sudo apt clean
+
+rm -f cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
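(The original cleanup line ran `rm -f` on the download URL, which removes nothing; wget saves the package under its basename in the working directory, which is what the scripts now delete. `-y` keeps apt from prompting in the non-interactive CI shell.) A hedged post-install sanity check that could follow either installer script, assuming the toolkit lands where the matching -env.sh files expect:

    # fail the job early if nvcc is missing from the expected prefix
    /usr/local/cuda-12.0/bin/nvcc --version || exit 1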
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 72df6053c..bc01441ec 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -1,49 +1,83 @@
-# This workflow will upload a Python Package to Release asset
+# This workflow will:
+# - Create a new GitHub release
+# - Build wheels for supported architectures
+# - Deploy the wheels to the GitHub release
+# - Publish the source distribution to PyPI
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
 
-name: Python Package
+name: Build wheels and deploy
 
+#on:
+#  create:
+#    tags:
+#      - '**'
 on:
-  create:
-    tags:
-      - '**'
+  push
 
 jobs:
-  release:
-    name: Create Release
-    runs-on: ubuntu-latest
-    steps:
-      - name: Get the tag version
-        id: extract_branch
-        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
-        shell: bash
+  # setup_release:
+  #   name: Create Release
+  #   runs-on: ubuntu-latest
+  #   steps:
+  #     - name: Get the tag version
+  #       id: extract_branch
+  #       run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+  #       shell: bash
 
-      - name: Create Release
-        id: create_release
-        uses: actions/create-release@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          tag_name: ${{ steps.extract_branch.outputs.branch }}
-          release_name: ${{ steps.extract_branch.outputs.branch }}
+  #     - name: Create Release
+  #       id: create_release
+  #       uses: actions/create-release@v1
+  #       env:
+  #         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+  #       with:
+  #         tag_name: ${{ steps.extract_branch.outputs.branch }}
+  #         release_name: ${{ steps.extract_branch.outputs.branch }}
 
-  wheel:
+  build_wheels:
     name: Build Wheel
     runs-on: ${{ matrix.os }}
-    needs: release
-
+    #needs: setup_release
+
     strategy:
       fail-fast: false
       matrix:
-          # os: [ubuntu-20.04]
-          os: [ubuntu-18.04]
-          python-version: ['3.7', '3.8', '3.9', '3.10']
-          torch-version: [1.11.0, 1.12.0, 1.12.1]
-          cuda-version: ['113', '116']
+          os: [ubuntu-20.04, ubuntu-22.04]
+          #python-version: ['3.7', '3.8', '3.9', '3.10']
+          #torch-version: ['1.11.0', '1.12.0', '1.13.0', '2.0.1']
+          #cuda-version: ['113', '116', '117', '120']
+          python-version: ['3.10']
+          torch-version: ['2.0.1']
+          cuda-version: ['120']
          exclude:
-            - torch-version: 1.11.0
+            # Nvidia only supports 11.7+ for ubuntu-22.04
+            - os: ubuntu-22.04
               cuda-version: '116'
+            - os: ubuntu-22.04
+              cuda-version: '113'
+            # Torch only builds cuda 117 for 1.13.0+
+            - cuda-version: '117'
+              torch-version: '1.11.0'
+            - cuda-version: '117'
+              torch-version: '1.12.0'
+            # Torch only builds cuda 116 for 1.12.0+
+            - cuda-version: '116'
+              torch-version: '1.11.0'
+            # Torch only builds cuda 120 for 2.0.1+
+            - cuda-version: '120'
+              torch-version: '1.11.0'
+            - cuda-version: '120'
+              torch-version: '1.12.0'
+            - cuda-version: '120'
+              torch-version: '1.13.0'
+            # 1.13.0 drops support for cuda 11.3
+            - cuda-version: '113'
+              torch-version: '1.13.0'
+            - cuda-version: '113'
+              torch-version: '2.0.1'
+            # Fails with "Validation Error" on artifact upload
+            - cuda-version: '117'
+              torch-version: '1.13.0'
+              os: ubuntu-20.04
 
     steps:
       - name: Checkout
@@ -82,13 +116,24 @@ jobs:
       - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        run: |
          pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
-          pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
+          pip install --no-cache-dir torch==${{ matrix.torch-version }}
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
        shell: bash
+
+      # - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
+      #   run: |
+      #     pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
+      #     pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
+      #     python --version
+      #     python -c "import torch; print('PyTorch:', torch.__version__)"
+      #     python -c "import torch; print('CUDA:', torch.version.cuda)"
+      #     python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
+      #   shell:
+      #     bash
 
      - name: Get the tag version
        id: extract_branch
@@ -104,24 +149,60 @@
      - name: Build wheel
        run: |
+          export FLASH_ATTENTION_FORCE_BUILD="TRUE"
          export FORCE_CUDA="1"
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
-          pip install wheel
+          pip install ninja packaging setuptools wheel
          python setup.py bdist_wheel --dist-dir=dist
          tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
-          ls dist/*whl |xargs -I {} mv {} ${wheel_name}
+          ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
-
-      - name: Upload Release Asset
-        id: upload_release_asset
-        uses: actions/upload-release-asset@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
-          asset_path: ./${{env.wheel_name}}
-          asset_name: ${{env.wheel_name}}
-          asset_content_type: application/*
\ No newline at end of file
+
+      - name: Log Built Wheels
+        run: |
+          ls dist
+
+      # - name: Upload Release Asset
+      #   id: upload_release_asset
+      #   uses: actions/upload-release-asset@v1
+      #   env:
+      #     GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      #   with:
+      #     upload_url: ${{ steps.get_current_release.outputs.upload_url }}
+      #     asset_path: ./dist/${{env.wheel_name}}
+      #     asset_name: ${{env.wheel_name}}
+      #     asset_content_type: application/*
+
+  # publish_package:
+  #   name: Publish package
+  #   needs: [build_wheels]
+
+  #   runs-on: ubuntu-latest
+
+  #   steps:
+  #     - uses: actions/checkout@v3
+
+  #     - uses: actions/setup-python@v4
+  #       with:
+  #         python-version: '3.10'
+
+  #     - name: Install dependencies
+  #       run: |
+  #         pip install ninja packaging setuptools wheel twine
+  #         pip install torch
+
+  #     - name: Build core package
+  #       env:
+  #         FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
+  #       run: |
+  #         python setup.py sdist --dist-dir=dist
+
+  #     - name: Deploy
+  #       env:
+  #         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+  #         TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
+  #       run: |
+  #         python -m twine upload dist/*
diff --git a/setup.py b/setup.py
index 1cef26033..fa27ca24d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,12 +6,16 @@
 import ast
 from pathlib import Path
 from packaging.version import parse, Version
+import platform
 
 from setuptools import setup, find_packages
 import subprocess
 
+import urllib.request
+import urllib.error
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 
 
 with open("README.md", "r", encoding="utf-8") as fh:
@@ -21,6 +25,30 @@
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+PACKAGE_NAME = "flash_attn"
+
+BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+
+# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+
+
+def get_platform():
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    if sys.platform.startswith('linux'):
+        return 'linux_x86_64'
+    elif sys.platform == 'darwin':
+        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
+        return f'macosx_{mac_version}_x86_64'
+    elif sys.platform == 'win32':
+        return 'win_amd64'
+    else:
+        raise ValueError('Unsupported platform: {}'.format(sys.platform))
+
 
 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -90,102 +118,101 @@ def append_nvcc_threads(nvcc_extra_args):
     else:
         os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
-
-print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-
 cmdclass = {}
 ext_modules = []
 
-# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
-# See https://github.com/pytorch/pytorch/pull/70650
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
-    generator_flag = ["-DOLD_GENERATOR_PATH"]
-
-raise_if_cuda_home_none("flash_attn")
-# Check, if CUDA11 is installed for compute capability 8.0
-cc_flag = []
-_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-if bare_metal_version < Version("11.0"):
-    raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
-# cc_flag.append("-gencode")
-# cc_flag.append("arch=compute_75,code=sm_75")
-cc_flag.append("-gencode")
-cc_flag.append("arch=compute_80,code=sm_80")
-if bare_metal_version >= Version("11.8"):
+if not SKIP_CUDA_BUILD:
+    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+    # See https://github.com/pytorch/pytorch/pull/70650
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    raise_if_cuda_home_none("flash_attn")
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version < Version("11.0"):
+        raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
+    # cc_flag.append("-gencode")
+    # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_90,code=sm_90")
-
-subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
-ext_modules.append(
-    CUDAExtension(
-        name="flash_attn_2_cuda",
-        sources=[
-            "csrc/flash_attn/flash_api.cpp",
-            "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
-        ],
-        extra_compile_args={
-            "cxx": ["-O3", "-std=c++17"] + generator_flag,
-            "nvcc": append_nvcc_threads(
-                [
-                    "-O3",
-                    "-std=c++17",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_HALF2_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "--ptxas-options=-v",
-                    # "--ptxas-options=-O2",
-                    "-lineinfo"
-                ]
-                + generator_flag
-                + cc_flag
-            ),
-        },
-        include_dirs=[
-            Path(this_dir) / 'csrc' / 'flash_attn',
-            Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
-            Path(this_dir) / 'csrc' / 'cutlass' / 'include',
-        ],
-    )
-)
+    cc_flag.append("arch=compute_80,code=sm_80")
+    if bare_metal_version >= Version("11.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90,code=sm_90")
+
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
+    ext_modules.append(
+        CUDAExtension(
+            name="flash_attn_2_cuda",
+            sources=[
+                "csrc/flash_attn/flash_api.cpp",
+                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-std=c++17"] + generator_flag,
+                "nvcc": append_nvcc_threads(
+                    [
+                        "-O3",
+                        "-std=c++17",
+                        "-U__CUDA_NO_HALF_OPERATORS__",
+                        "-U__CUDA_NO_HALF_CONVERSIONS__",
+                        "-U__CUDA_NO_HALF2_OPERATORS__",
+                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                        "--expt-relaxed-constexpr",
+                        "--expt-extended-lambda",
+                        "--use_fast_math",
+                        "--ptxas-options=-v",
+                        # "--ptxas-options=-O2",
+                        "-lineinfo"
+                    ]
+                    + generator_flag
+                    + cc_flag
+                ),
+            },
+            include_dirs=[
+                Path(this_dir) / 'csrc' / 'flash_attn',
+                Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
+                Path(this_dir) / 'csrc' / 'cutlass' / 'include',
+            ],
+        )
+    )
 
 
 def get_package_version():
@@ -199,8 +226,61 @@ def get_package_version():
     return str(public_version)
 
 
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
+    find an existing wheel (which is currently the case for all flash attention installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and short-circuit the standard full build pipeline.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return super().run()
+
+        raise_if_cuda_home_none("flash_attn")
+
+        # Determine the version numbers that will be used to determine the correct wheel
+        _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_version_raw = parse(torch.__version__)
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        platform_name = get_platform()
+        flash_version = get_package_version()
+        cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
+
+        # Determine wheel URL based on CUDA version, torch version, python version and OS
+        wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
+        wheel_url = BASE_WHEEL_URL.format(
+            tag_name=f"v{flash_version}",
+            wheel_name=wheel_filename
+        )
+        print("Guessing wheel URL: ", wheel_url)
+
+        try:
+            urllib.request.urlretrieve(wheel_url, wheel_filename)
+
+            # Make the archive
+            # Lifted from the root wheel processing command
+            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
+            if not os.path.exists(self.dist_dir):
+                os.makedirs(self.dist_dir)
+
+            impl_tag, abi_tag, plat_tag = self.get_tag()
+            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+
+            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            print("Raw wheel path", wheel_path)
+            os.rename(wheel_filename, wheel_path)
+        except urllib.error.HTTPError:
+            print("Precompiled wheel not found. Building from source...")
+            # If the wheel could not be downloaded, build from source
+            super().run()
+
+
 setup(
-    name="flash_attn",
+    # @pierce - TODO: Revert for official release
+    name=PACKAGE_NAME,
     version=get_package_version(),
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
@@ -208,6 +288,8 @@ def get_package_version():
     author="Tri Dao",
     author_email="trid@cs.stanford.edu",
     description="Flash Attention: Fast and Memory-Efficient Exact Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
     url="https://github.com/Dao-AILab/flash-attention",
     classifiers=[
         "Programming Language :: Python :: 3",
@@ -215,7 +297,12 @@ def get_package_version():
         "Operating System :: Unix",
     ],
     ext_modules=ext_modules,
-    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+    cmdclass={
+        'bdist_wheel': CachedWheelsCommand,
+        "build_ext": BuildExtension
+    } if ext_modules else {
+        'bdist_wheel': CachedWheelsCommand,
+    },
     python_requires=">=3.7",
     install_requires=[
         "torch",
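The rename in the Build wheel step and the URL guess in CachedWheelsCommand are two halves of one contract: CI splices +cu{cuda}torch{torch} into the wheel's local version, and setup.py later reconstructs that exact filename from the installed torch, CUDA, and Python. A sketch with hypothetical version values for one matrix entry (flash_attn 2.0.0, cp310, torch 2.0.1, cu120, Linux):

    tmpname=cu120torch2.0.1
    echo flash_attn-2.0.0-cp310-cp310-linux_x86_64.whl | sed "s/-/+$tmpname-/2"
    # -> flash_attn-2.0.0+cu120torch2.0.1-cp310-cp310-linux_x86_64.whl

    # the release-asset URL CachedWheelsCommand would then guess:
    wheel=flash_attn-2.0.0+cu120torch2.0.1-cp310-cp310-linux_x86_64.whl
    url=https://github.com/Dao-AILab/flash-attention/releases/download/v2.0.0/$wheel
    curl -fsILo /dev/null "$url" && echo "prebuilt wheel found" || echo "would build from source"

Setting FLASH_ATTENTION_FORCE_BUILD=TRUE (as the build job does) bypasses this lookup entirely, and FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE (as the commented publish job does) lets `python setup.py sdist` run without any CUDA compilation.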