Skip to content

Commit

Permalink
refactor: Refactor JIT and AOT build script (#567)
Browse files Browse the repository at this point in the history
Previously, JIT and AOT packaging is a bit broken. This PR produces good
sdist for JIT mode, and wheel for AOT mode.

## Changes

Common changes:
1. Remove the symlinks. Symlinks causes lots of duplication when search
in VSCode.
2. In package distribution (sdist or wheel), add data files to
`python/flashinfer/data/`, i.e. inside the python package folder. This
is strongly recommended by setuptools.
* Data files include: `version.txt`, FlashInfer headers, Cutlass
headers.
* Symlinks will be created when building wheel, and will be removed when
finished unless it's using `develop` command.
3. Exclude unneeded cutlass docs and files from wheel and sdist.

AOT changes:
1. Remove `flashinfer-aot` dir. Contents are moved to `python/`.
2. Merge all kernels into one pybind. This is good for compilation
speed. (`_kernels_sm90` is preserved as a separated `.so` file.)
3. AOT wheel can now be built with the following command:
    ```bash
    cd flashinfer/python
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel
    ls -la dist/
    ```
4. AOT wheel can also be built for editable install (develop purpose)
    ```bash
    cd flashinfer/python
    TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py develop
    ```

JIT changes:
1. JIT mode can now be installed in various ways:
    ```bash
    cd flashinfer/python
    pip install -v .     # Regular install from source
    pip install -v -e .  # Editable install

    python -m build --sdist               # Build sdist
    pip install dist/flashinfer-*.tar.gz  # Install from sdist
    ```

## Directory structure of built package

See attached.

[dir-wheel.txt](https://github.com/user-attachments/files/17562193/dir-wheel.txt)

[dir-sdist.txt](https://github.com/user-attachments/files/17562194/dir-sdist.txt)

## Tests

I was able to pass `pytest -sv test_norm.py test_bmm_fp8.py` using
various way of installation:

1. Editable install
2. Regular install from source
3. Install from sdist
4. Install from wheel
  • Loading branch information
abcdabcd987 authored Oct 30, 2024
1 parent e46d9a7 commit 7df90dd
Show file tree
Hide file tree
Showing 40 changed files with 425 additions and 344 deletions.
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ src/generated/
python/csrc/generated/
python/flashinfer/_build_meta.py
python/flashinfer/jit/aot_config.py
flashinfer-aot/csrc_aot/generated/
python/csrc_aot/generated/

# Package files
python/flashinfer/data/
python/flashinfer/version.txt
python/MANIFEST.in

# Generated documentation files
docs/generated
Expand Down
15 changes: 12 additions & 3 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ You can follow the steps below to install FlashInfer from source code:
pip install ninja
4. Compile FlashInfer:
4. Install FlashInfer:

.. tabs::

Expand All @@ -153,8 +153,17 @@ You can follow the steps below to install FlashInfer from source code:

.. code-block:: bash
cd flashinfer/flashinfer-aot
pip install -e . -v
cd flashinfer/python
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel
pip install dist/flashinfer-*.whl
.. tab:: Create sdist for JIT mode

.. code-block:: bash
cd flashinfer/python
python -m build --sdist
ls -la dist/
C++ API
-------
Expand Down
1 change: 0 additions & 1 deletion flashinfer-aot/3rdparty

This file was deleted.

12 changes: 0 additions & 12 deletions flashinfer-aot/MANIFEST.in

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/csrc

This file was deleted.

45 changes: 0 additions & 45 deletions flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu

This file was deleted.

56 changes: 0 additions & 56 deletions flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/flashinfer

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/include

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/version.txt

This file was deleted.

14 changes: 7 additions & 7 deletions include/flashinfer/attention/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
* the new batch size after the partition.
*/
template <typename IdType>
auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector<IdType>& num_pages,
const uint32_t min_num_pages_per_batch = 1) {
uint32_t low = min_num_pages_per_batch, high = 0;
Expand All @@ -77,7 +77,7 @@ auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
return std::make_tuple(low, new_batch_size);
}

auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split,
inline auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split,
const std::vector<int64_t>& packed_qo_len_arr,
const std::vector<int64_t>& kv_len_arr,
const uint32_t qo_chunk_size,
Expand Down Expand Up @@ -129,7 +129,7 @@ auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split,
*/
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t batch_size, typename AttentionVariant::IdType* kv_indptr_h,
const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph,
Expand Down Expand Up @@ -201,7 +201,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
* \return status Indicates whether CUDA calls are successful
*/
template <typename IdType>
auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) {
inline auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) {
std::vector<IdType> request_indices, kv_tile_indices, o_indptr;
o_indptr.push_back(0);

Expand Down Expand Up @@ -277,7 +277,7 @@ struct DecodePlanInfo {
};

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
void* page_locked_int_buffer, size_t int_workspace_size_in_bytes,
DecodePlanInfo& plan_info, typename AttentionVariant::IdType* indptr_h,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
Expand Down Expand Up @@ -350,7 +350,7 @@ cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes,
}

template <typename IdType>
auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, uint32_t max_batch_size_if_split,
bool enable_cuda_graph) {
Expand Down Expand Up @@ -520,7 +520,7 @@ struct PrefillPlanInfo {
};

template <typename IdType>
cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
void* page_locked_int_buffer, size_t int_workspace_size_in_bytes,
PrefillPlanInfo& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
Expand Down
1 change: 0 additions & 1 deletion python/3rdparty

This file was deleted.

12 changes: 0 additions & 12 deletions python/MANIFEST.in

This file was deleted.

Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
limitations under the License.
"""

import sys
import re
import itertools
from literal_map import (
mask_mode_literal,
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
mask_mode_literal,
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import argparse
from pathlib import Path
from literal_map import (
pos_encoding_mode_literal,

from .literal_map import (
bool_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions python/aot_MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# MANIFEST.in for AOT wheel

prune */__pycache__
prune csrc
prune csrc_aot
exclude aot_setup.py
exclude setup.py

include flashinfer/data/version.txt
graft flashinfer/data/csrc
graft flashinfer/data/include
graft flashinfer/data/cutlass/include
graft flashinfer/data/cutlass/tools/util/include
Loading

0 comments on commit 7df90dd

Please sign in to comment.