From 7df90ddfbb839c9e5b3723f43ab0b48a906ad286 Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Wed, 30 Oct 2024 00:20:25 -0700 Subject: [PATCH] refactor: Refactor JIT and AOT build script (#567) Previously, JIT and AOT packaging is a bit broken. This PR produces good sdist for JIT mode, and wheel for AOT mode. ## Changes Common changes: 1. Remove the symlinks. Symlinks causes lots of duplication when search in VSCode. 2. In package distribution (sdist or wheel), add data files to `python/flashinfer/data/`, i.e. inside the python package folder. This is strongly recommended by setuptools. * Data files include: `version.txt`, FlashInfer headers, Cutlass headers. * Symlinks will be created when building wheel, and will be removed when finished unless it's using `develop` command. 3. Exclude unneeded cutlass docs and files from wheel and sdist. AOT changes: 1. Remove `flashinfer-aot` dir. Contents are moved to `python/`. 2. Merge all kernels into one pybind. This is good for compilation speed. (`_kernels_sm90` is preserved as a separated `.so` file.) 3. AOT wheel can now be built with the following command: ```bash cd flashinfer/python TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel ls -la dist/ ``` 4. AOT wheel can also be built for editable install (develop purpose) ```bash cd flashinfer/python TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py develop ``` JIT changes: 1. JIT mode can now be installed in various ways: ```bash cd flashinfer/python pip install -v . # Regular install from source pip install -v -e . # Editable install python -m build --sdist # Build sdist pip install dist/flashinfer-*.tar.gz # Install from sdist ``` ## Directory structure of built package See attached. [dir-wheel.txt](https://github.com/user-attachments/files/17562193/dir-wheel.txt) [dir-sdist.txt](https://github.com/user-attachments/files/17562194/dir-sdist.txt) ## Tests I was able to pass `pytest -sv test_norm.py test_bmm_fp8.py` using various way of installation: 1. Editable install 2. Regular install from source 3. Install from sdist 4. Install from wheel --- .gitignore | 7 +- docs/installation.rst | 15 +- flashinfer-aot/3rdparty | 1 - flashinfer-aot/MANIFEST.in | 12 - flashinfer-aot/csrc | 1 - .../csrc_aot/flashinfer_ops_decode.cu | 45 ---- .../csrc_aot/flashinfer_ops_prefill.cu | 56 ----- flashinfer-aot/flashinfer | 1 - flashinfer-aot/include | 1 - flashinfer-aot/version.txt | 1 - include/flashinfer/attention/scheduler.cuh | 14 +- python/3rdparty | 1 - python/MANIFEST.in | 12 - python/_aot_build_utils/__init__.py | 0 .../generate_batch_paged_decode_inst.py | 9 +- .../generate_batch_paged_prefill_inst.py | 12 +- .../generate_batch_ragged_prefill_inst.py | 11 +- .../generate_dispatch_inc.py | 5 +- .../generate_single_decode_inst.py | 9 +- .../generate_single_prefill_inst.py | 9 +- .../_aot_build_utils}/literal_map.py | 0 python/aot_MANIFEST.in | 13 + .../setup.py => python/aot_setup.py | 167 +++++++------ .../csrc_aot/activation.cu | 0 .../csrc_aot/batch_decode.cu | 0 .../csrc_aot/batch_prefill.cu | 0 .../csrc_aot/flashinfer_ops.cu | 232 +++++++++++++----- .../csrc_aot/flashinfer_sm90_ops.cu | 2 +- .../csrc_aot/pytorch_extension_utils.h | 0 .../csrc_aot/single_decode.cu | 0 .../csrc_aot/single_prefill.cu | 0 python/flashinfer/decode.py | 10 +- python/flashinfer/jit/env.py | 10 +- python/flashinfer/prefill.py | 16 +- python/flashinfer/triton/kernels/__init__.py | 0 python/include | 1 - python/jit_MANIFEST.in | 17 ++ python/setup.py | 75 ++++-- python/version.txt | 1 - scripts/run-ci-build-wheel.sh | 3 +- 40 files changed, 425 insertions(+), 344 deletions(-) delete mode 120000 flashinfer-aot/3rdparty delete mode 100644 flashinfer-aot/MANIFEST.in delete mode 120000 flashinfer-aot/csrc delete mode 100644 flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu delete mode 100644 flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu delete mode 120000 flashinfer-aot/flashinfer delete mode 120000 flashinfer-aot/include delete mode 120000 flashinfer-aot/version.txt delete mode 120000 python/3rdparty delete mode 100644 python/MANIFEST.in create mode 100644 python/_aot_build_utils/__init__.py rename {flashinfer-aot => python/_aot_build_utils}/generate_batch_paged_decode_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/generate_batch_paged_prefill_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/generate_batch_ragged_prefill_inst.py (99%) rename {flashinfer-aot => python/_aot_build_utils}/generate_dispatch_inc.py (99%) rename {flashinfer-aot => python/_aot_build_utils}/generate_single_decode_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/generate_single_prefill_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/literal_map.py (100%) create mode 100644 python/aot_MANIFEST.in rename flashinfer-aot/setup.py => python/aot_setup.py (82%) rename {flashinfer-aot => python}/csrc_aot/activation.cu (100%) rename {flashinfer-aot => python}/csrc_aot/batch_decode.cu (100%) rename {flashinfer-aot => python}/csrc_aot/batch_prefill.cu (100%) rename {flashinfer-aot => python}/csrc_aot/flashinfer_ops.cu (64%) rename {flashinfer-aot => python}/csrc_aot/flashinfer_sm90_ops.cu (99%) rename {flashinfer-aot => python}/csrc_aot/pytorch_extension_utils.h (100%) rename {flashinfer-aot => python}/csrc_aot/single_decode.cu (100%) rename {flashinfer-aot => python}/csrc_aot/single_prefill.cu (100%) create mode 100644 python/flashinfer/triton/kernels/__init__.py delete mode 120000 python/include create mode 100644 python/jit_MANIFEST.in delete mode 120000 python/version.txt diff --git a/.gitignore b/.gitignore index 14efeef1..fa13a77b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,12 @@ src/generated/ python/csrc/generated/ python/flashinfer/_build_meta.py python/flashinfer/jit/aot_config.py -flashinfer-aot/csrc_aot/generated/ +python/csrc_aot/generated/ + +# Package files +python/flashinfer/data/ +python/flashinfer/version.txt +python/MANIFEST.in # Generated documentation files docs/generated diff --git a/docs/installation.rst b/docs/installation.rst index 35bf84a1..4a423305 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -138,7 +138,7 @@ You can follow the steps below to install FlashInfer from source code: pip install ninja -4. Compile FlashInfer: +4. Install FlashInfer: .. tabs:: @@ -153,8 +153,17 @@ You can follow the steps below to install FlashInfer from source code: .. code-block:: bash - cd flashinfer/flashinfer-aot - pip install -e . -v + cd flashinfer/python + TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel + pip install dist/flashinfer-*.whl + + .. tab:: Create sdist for JIT mode + + .. code-block:: bash + + cd flashinfer/python + python -m build --sdist + ls -la dist/ C++ API ------- diff --git a/flashinfer-aot/3rdparty b/flashinfer-aot/3rdparty deleted file mode 120000 index 303a6484..00000000 --- a/flashinfer-aot/3rdparty +++ /dev/null @@ -1 +0,0 @@ -../3rdparty \ No newline at end of file diff --git a/flashinfer-aot/MANIFEST.in b/flashinfer-aot/MANIFEST.in deleted file mode 100644 index b20747fe..00000000 --- a/flashinfer-aot/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -# sdist & wheel -include version.txt -recursive-include include * -recursive-include csrc * -recursive-include 3rdparty/cutlass * - -# wheel-only -exclude flashinfer/_build_meta.py - -# Unneeded files -prune */__pycache__ -global-exclude *.so diff --git a/flashinfer-aot/csrc b/flashinfer-aot/csrc deleted file mode 120000 index bf562722..00000000 --- a/flashinfer-aot/csrc +++ /dev/null @@ -1 +0,0 @@ -../python/csrc \ No newline at end of file diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu deleted file mode 100644 index fe665a1d..00000000 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, - std::optional alibi_slopes, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta); - -std::vector BatchDecodeWithPagedKVCachePlan( - bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, - torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); - -torch::Tensor BatchDecodeWithPagedKVCacheRun( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, - unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, - "Single-request decode with KV-Cache operator"); - m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); - m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); -} diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu deleted file mode 100644 index 2f353d02..00000000 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -torch::Tensor single_prefill_with_kv_cache( - unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, - std::optional maybe_packed_custom_mask, torch::Tensor tmp, - std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse); - -std::vector BatchPrefillWithKVCachePlan( - unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); - -torch::Tensor BatchPrefillWithRaggedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse); - -torch::Tensor BatchPrefillWithPagedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, - torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, - unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, - "Single-request prefill attention with KV-Cache operator"); - m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); - m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); - m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); -} diff --git a/flashinfer-aot/flashinfer b/flashinfer-aot/flashinfer deleted file mode 120000 index c5f9b1c7..00000000 --- a/flashinfer-aot/flashinfer +++ /dev/null @@ -1 +0,0 @@ -../python/flashinfer \ No newline at end of file diff --git a/flashinfer-aot/include b/flashinfer-aot/include deleted file mode 120000 index f5030fe8..00000000 --- a/flashinfer-aot/include +++ /dev/null @@ -1 +0,0 @@ -../include \ No newline at end of file diff --git a/flashinfer-aot/version.txt b/flashinfer-aot/version.txt deleted file mode 120000 index aa4e5bec..00000000 --- a/flashinfer-aot/version.txt +++ /dev/null @@ -1 +0,0 @@ -../version.txt \ No newline at end of file diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 423c989f..ecafee1e 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -50,7 +50,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ * the new batch size after the partition. */ template -auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( +inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector& num_pages, const uint32_t min_num_pages_per_batch = 1) { uint32_t low = min_num_pages_per_batch, high = 0; @@ -77,7 +77,7 @@ auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( return std::make_tuple(low, new_batch_size); } -auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split, +inline auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split, const std::vector& packed_qo_len_arr, const std::vector& kv_len_arr, const uint32_t qo_chunk_size, @@ -129,7 +129,7 @@ auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split, */ template -cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( +inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, typename AttentionVariant::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, @@ -201,7 +201,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( * \return status Indicates whether CUDA calls are successful */ template -auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) { +inline auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) { std::vector request_indices, kv_tile_indices, o_indptr; o_indptr.push_back(0); @@ -277,7 +277,7 @@ struct DecodePlanInfo { }; template -cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, +inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, DecodePlanInfo& plan_info, typename AttentionVariant::IdType* indptr_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, @@ -350,7 +350,7 @@ cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, } template -auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, uint32_t max_batch_size_if_split, bool enable_cuda_graph) { @@ -520,7 +520,7 @@ struct PrefillPlanInfo { }; template -cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, +inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, diff --git a/python/3rdparty b/python/3rdparty deleted file mode 120000 index 303a6484..00000000 --- a/python/3rdparty +++ /dev/null @@ -1 +0,0 @@ -../3rdparty \ No newline at end of file diff --git a/python/MANIFEST.in b/python/MANIFEST.in deleted file mode 100644 index b20747fe..00000000 --- a/python/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -# sdist & wheel -include version.txt -recursive-include include * -recursive-include csrc * -recursive-include 3rdparty/cutlass * - -# wheel-only -exclude flashinfer/_build_meta.py - -# Unneeded files -prune */__pycache__ -global-exclude *.so diff --git a/python/_aot_build_utils/__init__.py b/python/_aot_build_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flashinfer-aot/generate_batch_paged_decode_inst.py b/python/_aot_build_utils/generate_batch_paged_decode_inst.py similarity index 98% rename from flashinfer-aot/generate_batch_paged_decode_inst.py rename to python/_aot_build_utils/generate_batch_paged_decode_inst.py index efd1945b..7808c33b 100644 --- a/flashinfer-aot/generate_batch_paged_decode_inst.py +++ b/python/_aot_build_utils/generate_batch_paged_decode_inst.py @@ -14,14 +14,15 @@ limitations under the License. """ -import sys import re -from literal_map import ( - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, idtype_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_batch_paged_prefill_inst.py b/python/_aot_build_utils/generate_batch_paged_prefill_inst.py similarity index 98% rename from flashinfer-aot/generate_batch_paged_prefill_inst.py rename to python/_aot_build_utils/generate_batch_paged_prefill_inst.py index 21328aae..97f1423a 100644 --- a/flashinfer-aot/generate_batch_paged_prefill_inst.py +++ b/python/_aot_build_utils/generate_batch_paged_prefill_inst.py @@ -14,16 +14,16 @@ limitations under the License. """ -import sys import re -import itertools -from literal_map import ( - mask_mode_literal, - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_batch_ragged_prefill_inst.py b/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py similarity index 99% rename from flashinfer-aot/generate_batch_ragged_prefill_inst.py rename to python/_aot_build_utils/generate_batch_ragged_prefill_inst.py index 59acc67b..f5631303 100644 --- a/flashinfer-aot/generate_batch_ragged_prefill_inst.py +++ b/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py @@ -14,15 +14,16 @@ limitations under the License. """ -import sys import re -from literal_map import ( - mask_mode_literal, - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_dispatch_inc.py b/python/_aot_build_utils/generate_dispatch_inc.py similarity index 99% rename from flashinfer-aot/generate_dispatch_inc.py rename to python/_aot_build_utils/generate_dispatch_inc.py index f3ad9db8..30552e6e 100644 --- a/flashinfer-aot/generate_dispatch_inc.py +++ b/python/_aot_build_utils/generate_dispatch_inc.py @@ -16,10 +16,11 @@ import argparse from pathlib import Path -from literal_map import ( - pos_encoding_mode_literal, + +from .literal_map import ( bool_literal, mask_mode_literal, + pos_encoding_mode_literal, ) diff --git a/flashinfer-aot/generate_single_decode_inst.py b/python/_aot_build_utils/generate_single_decode_inst.py similarity index 98% rename from flashinfer-aot/generate_single_decode_inst.py rename to python/_aot_build_utils/generate_single_decode_inst.py index 754e185f..ce24d7e7 100644 --- a/flashinfer-aot/generate_single_decode_inst.py +++ b/python/_aot_build_utils/generate_single_decode_inst.py @@ -14,13 +14,14 @@ limitations under the License. """ -import sys import re -from literal_map import ( - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_single_prefill_inst.py b/python/_aot_build_utils/generate_single_prefill_inst.py similarity index 98% rename from flashinfer-aot/generate_single_prefill_inst.py rename to python/_aot_build_utils/generate_single_prefill_inst.py index eb54ed4e..49eefd17 100644 --- a/flashinfer-aot/generate_single_prefill_inst.py +++ b/python/_aot_build_utils/generate_single_prefill_inst.py @@ -14,14 +14,15 @@ limitations under the License. """ -import sys import re -from literal_map import ( - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, mask_mode_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/literal_map.py b/python/_aot_build_utils/literal_map.py similarity index 100% rename from flashinfer-aot/literal_map.py rename to python/_aot_build_utils/literal_map.py diff --git a/python/aot_MANIFEST.in b/python/aot_MANIFEST.in new file mode 100644 index 00000000..e1988769 --- /dev/null +++ b/python/aot_MANIFEST.in @@ -0,0 +1,13 @@ +# MANIFEST.in for AOT wheel + +prune */__pycache__ +prune csrc +prune csrc_aot +exclude aot_setup.py +exclude setup.py + +include flashinfer/data/version.txt +graft flashinfer/data/csrc +graft flashinfer/data/include +graft flashinfer/data/cutlass/include +graft flashinfer/data/cutlass/tools/util/include diff --git a/flashinfer-aot/setup.py b/python/aot_setup.py similarity index 82% rename from flashinfer-aot/setup.py rename to python/aot_setup.py index 80fd4ea9..ee9cfa2a 100644 --- a/flashinfer-aot/setup.py +++ b/python/aot_setup.py @@ -14,33 +14,36 @@ limitations under the License. """ -from typing import List, Tuple - +import argparse +import contextlib import copy -import pathlib +import itertools import os +import pathlib +import platform import re -import itertools +import shutil import subprocess -import platform +import sys +import warnings +from typing import Iterator, List, Tuple import setuptools -import argparse import torch import torch.utils.cpp_extension as torch_cpp_ext -from collections import namedtuple -import generate_single_decode_inst, generate_single_prefill_inst, generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, generate_dispatch_inc +root = pathlib.Path(__file__).resolve().parents[1] +sys.path.append(str(root / "python")) -root = pathlib.Path(__name__).parent - - -# cuda arch check for fp8 at the moment. -for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): - arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) - if arch < 75: - raise RuntimeError("FlashInfer requires sm75+") +from _aot_build_utils import ( + generate_batch_paged_decode_inst, + generate_batch_paged_prefill_inst, + generate_batch_ragged_prefill_inst, + generate_dispatch_inc, + generate_single_decode_inst, + generate_single_prefill_inst, +) enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1" enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1" @@ -61,7 +64,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None: def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: - path = root / "csrc_aot" / "generated" + path = root / "python" / "csrc_aot" / "generated" path.mkdir(parents=True, exist_ok=True) head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") @@ -104,13 +107,7 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: files_prefill = [] single_decode_uris = [] # single decode files - for ( - head_dim, - pos_encoding_mode, - ) in itertools.product( - head_dims, - pos_encoding_modes, - ): + for head_dim, pos_encoding_mode in itertools.product(head_dims, pos_encoding_modes): for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( itertools.product(fp16_dtypes, fp8_dtypes) ): @@ -278,6 +275,11 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: f"f16qk_{bool(allow_fp16_qk_reduction)}" ) + # Change to relative path + this_dir = pathlib.Path(__file__).parent.resolve() + files_prefill = [str(pathlib.Path(p).relative_to(this_dir)) for p in files_prefill] + files_decode = [str(pathlib.Path(p).relative_to(this_dir)) for p in files_decode] + return ( files_prefill, files_decode, @@ -313,14 +315,14 @@ def generate_build_meta() -> None: d["torch"] = torch.__version__ d["python"] = platform.python_version() d["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", None) - with open(root / "flashinfer" / "_build_meta.py", "w") as f: + with open(root / "python" / "flashinfer" / "_build_meta.py", "w") as f: f.write(f"__version__ = {version!r}\n") f.write(f"build_meta = {d!r}") def generate_aot_config(aot_kernel_uris: List[str]) -> None: aot_config_str = f"""prebuilt_ops_uri = set({aot_kernel_uris})""" - with open(root / "flashinfer" / "jit" / "aot_config.py", "w") as f: + with open(root / "python" / "flashinfer" / "jit" / "aot_config.py", "w") as f: f.write(aot_config_str) @@ -348,11 +350,43 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) +@contextlib.contextmanager +def link_data_files() -> Iterator[None]: + this_dir = pathlib.Path(__file__).parent + data_dir = root / "python" / "flashinfer" / "data" + if data_dir.exists(): + shutil.rmtree(data_dir) + data_dir.mkdir(parents=True) + + def ln(src: str, dst: str, is_dir: bool = False) -> None: + (data_dir / dst).symlink_to(root / src, target_is_directory=is_dir) + + ln("3rdparty/cutlass", "cutlass", True) + ln("include", "include", True) + ln("python/csrc", "csrc", True) + ln("version.txt", "version.txt") + (this_dir / "MANIFEST.in").unlink(True) + (this_dir / "MANIFEST.in").symlink_to("jit_MANIFEST.in") + + yield + + if sys.argv[1] != "develop": + shutil.rmtree(data_dir) + (this_dir / "MANIFEST.in").unlink(True) + + if __name__ == "__main__": + # cuda arch check for fp8 at the moment. + for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): + arch = int(re.search(r"compute_(\d+)", cuda_arch_flags).group(1)) + if arch < 75: + raise RuntimeError("FlashInfer requires sm75+") + remove_unwanted_pytorch_nvcc_flags() generate_build_meta() files_prefill, files_decode, aot_kernel_uris = get_instantiation_cu() generate_aot_config(aot_kernel_uris) + include_dirs = [ str(root.resolve() / "include"), str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm @@ -366,8 +400,7 @@ def __init__(self, *args, **kwargs) -> None: "nvcc": [ "-O3", "-std=c++17", - "--threads", - "1", + "--threads=1", "-Xfatbin", "-compress-all", "-use_fast_math", @@ -382,17 +415,23 @@ def __init__(self, *args, **kwargs) -> None: torch_cpp_ext.CUDAExtension( name="flashinfer._kernels", sources=[ + "csrc/bmm_fp8.cu", "csrc/cascade.cu", + "csrc/group_gemm.cu", + "csrc/norm.cu", "csrc/page.cu", + "csrc/quantization.cu", + "csrc/rope.cu", "csrc/sampling.cu", - "csrc/norm.cu", "csrc_aot/activation.cu", - "csrc/rope.cu", - "csrc/quantization.cu", - "csrc/group_gemm.cu", - "csrc/bmm_fp8.cu", - "csrc_aot/flashinfer_ops.cu" - ], + "csrc_aot/batch_decode.cu", + "csrc_aot/batch_prefill.cu", + "csrc_aot/flashinfer_ops.cu", + "csrc_aot/single_decode.cu", + "csrc_aot/single_prefill.cu", + ] + + files_decode + + files_prefill, include_dirs=include_dirs, extra_compile_args=extra_compile_args, ) @@ -408,41 +447,25 @@ def __init__(self, *args, **kwargs) -> None: extra_compile_args=extra_compile_args_sm90, ) ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._decode_kernels", - sources=[ - "csrc_aot/single_decode.cu", - "csrc_aot/flashinfer_ops_decode.cu", - "csrc_aot/batch_decode.cu", - ] - + files_decode, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - ) - ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._prefill_kernels", - sources=[ - "csrc_aot/single_prefill.cu", - "csrc_aot/flashinfer_ops_prefill.cu", - "csrc_aot/batch_prefill.cu", - ] - + files_prefill, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, + + # Suppress warnings complaining that: + # Package 'flashinfer.data*' is absent from the `packages` configuration. + warnings.filterwarnings("ignore", r".*flashinfer\.data.*", UserWarning) + + with link_data_files(): + setuptools.setup( + name="flashinfer", + version=get_version(), + packages=setuptools.find_packages( + include=["flashinfer*"], + exclude=["flashinfer.data*"], + ), + include_package_data=True, + author="FlashInfer team", + license="Apache License 2.0", + description="FlashInfer: Kernel Library for LLM Serving", + url="https://github.com/flashinfer-ai/flashinfer", + python_requires=">=3.8", + ext_modules=ext_modules, + cmdclass={"build_ext": NinjaBuildExtension}, ) - ) - setuptools.setup( - name="flashinfer", - version=get_version(), - packages=setuptools.find_packages(), - author="FlashInfer team", - license="Apache License 2.0", - description="FlashInfer: Kernel Library for LLM Serving", - url="https://github.com/flashinfer-ai/flashinfer", - python_requires=">=3.8", - ext_modules=ext_modules, - cmdclass={"build_ext": NinjaBuildExtension}, - ) diff --git a/flashinfer-aot/csrc_aot/activation.cu b/python/csrc_aot/activation.cu similarity index 100% rename from flashinfer-aot/csrc_aot/activation.cu rename to python/csrc_aot/activation.cu diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/python/csrc_aot/batch_decode.cu similarity index 100% rename from flashinfer-aot/csrc_aot/batch_decode.cu rename to python/csrc_aot/batch_decode.cu diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/python/csrc_aot/batch_prefill.cu similarity index 100% rename from flashinfer-aot/csrc_aot/batch_prefill.cu rename to python/csrc_aot/batch_prefill.cu diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu similarity index 64% rename from flashinfer-aot/csrc_aot/flashinfer_ops.cu rename to python/csrc_aot/flashinfer_ops.cu index 5cf365ca..e32ac3e9 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/python/csrc_aot/flashinfer_ops.cu @@ -15,11 +15,13 @@ */ #include -void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor kv_indices, - torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, - unsigned int layout); +//========== activation ========== + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +//========== cascade ========== std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, torch::Tensor s_b); @@ -29,43 +31,41 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe std::vector merge_states(torch::Tensor v, torch::Tensor s); -torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - bool deterministic); +//========== decode ========== -std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, - double top_p_val, bool deterministic); +torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, + torch::Tensor tmp, + std::optional alibi_slopes, + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta); -std::vector top_k_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic); +std::vector BatchDecodeWithPagedKVCachePlan( + bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); -std::vector min_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, - double min_p_val, bool deterministic); +torch::Tensor BatchDecodeWithPagedKVCacheRun( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, + unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse); -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic); +//========== gemm ========== -torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val); - -torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val); +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, + torch::Tensor& A_scale, torch::Tensor& B_scale); -torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val); +torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); -torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, - torch::Tensor uniform_samples, torch::Tensor target_probs, - torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, - bool deterministic); +//========== norm ========== void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); @@ -77,11 +77,56 @@ void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weig void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double eps); -void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +//========== page ========== -void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); +void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, + torch::Tensor append_indptr, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor kv_indices, + torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, + unsigned int layout); -void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); +//========== prefill ========== + +torch::Tensor single_prefill_with_kv_cache( + unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, + std::optional maybe_packed_custom_mask, torch::Tensor tmp, + std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse); + +std::vector BatchPrefillWithKVCachePlan( + unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); + +torch::Tensor BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse); + +torch::Tensor BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse); + +//========== quantization ========== + +torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); + +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, + torch::Tensor output_indptr, const std::string& bitorder); + +//========== rope ========== void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, @@ -101,25 +146,99 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); -torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); +//========== sampling ========== -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, - torch::Tensor output_indptr, const std::string& bitorder); +torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, + bool deterministic); -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale); +std::vector top_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_p_arr, + double top_p_val, bool deterministic); -torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); +std::vector top_k_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic); + +std::vector min_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_min_p_arr, + double min_p_val, bool deterministic); + +std::vector top_k_top_p_sampling_from_probs( + torch::Tensor probs, torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic); + +torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, + double top_p_val); + +torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, + torch::Tensor uniform_samples, torch::Tensor target_probs, + torch::Tensor output_accepted_token_num, + torch::Tensor output_emitted_token_num, + bool deterministic); + +//========== pybind11 ========== PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + // activation + m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); + m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); + + // cascade m.def("merge_state", &merge_state, "Merge two self-attention states"); m.def("merge_state_in_place", &merge_state_in_place, "Merge another self-attention state in-place."); m.def("merge_states", &merge_states, "Merge multiple self-attention states"); + + // decode + m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, + "Single-request decode with KV-Cache operator"); + m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); + m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); + + // gemm + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); + m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); + + // norm + m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); + m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); + m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, + "Gemma Fused add root mean square normalization"); + + // page + m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + + // prefill + m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, + "Single-request prefill attention with KV-Cache operator"); + m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); + m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); + m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); + + // quantization + m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + + // rope + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); + m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, + "Apply Llama 3.1 style RoPE with positional ids"); + + // sampling m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, "Top-k sampling from probabilities"); @@ -134,21 +253,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); - m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); - m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); - m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); - m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, - "Apply Llama 3.1 style RoPE with positional ids"); - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); - m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); } diff --git a/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu b/python/csrc_aot/flashinfer_sm90_ops.cu similarity index 99% rename from flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu rename to python/csrc_aot/flashinfer_sm90_ops.cu index 5140982f..d4222473 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu +++ b/python/csrc_aot/flashinfer_sm90_ops.cu @@ -23,4 +23,4 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); -} \ No newline at end of file +} diff --git a/flashinfer-aot/csrc_aot/pytorch_extension_utils.h b/python/csrc_aot/pytorch_extension_utils.h similarity index 100% rename from flashinfer-aot/csrc_aot/pytorch_extension_utils.h rename to python/csrc_aot/pytorch_extension_utils.h diff --git a/flashinfer-aot/csrc_aot/single_decode.cu b/python/csrc_aot/single_decode.cu similarity index 100% rename from flashinfer-aot/csrc_aot/single_decode.cu rename to python/csrc_aot/single_decode.cu diff --git a/flashinfer-aot/csrc_aot/single_prefill.cu b/python/csrc_aot/single_prefill.cu similarity index 100% rename from flashinfer-aot/csrc_aot/single_prefill.cu rename to python/csrc_aot/single_prefill.cu diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 5215ae45..4e7cd548 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -81,9 +81,9 @@ def get_single_decode_module(*args): if args not in _single_decode_modules: uri = get_single_decode_uri(*args) if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _decode_kernels + from . import _kernels - run_func = _decode_kernels.single_decode_with_kv_cache + run_func = _kernels.single_decode_with_kv_cache else: run_func = compile_single_decode_module(*args).run @@ -143,7 +143,7 @@ def get_batch_decode_module(*args): if args not in _batch_decode_modules: uri = get_batch_decode_uri(*args) if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _decode_kernels + from . import _kernels # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later dtype_q = args[0] @@ -151,7 +151,7 @@ def get_batch_decode_module(*args): head_dim = args[4] use_logits_cap = args[7] plan_func = ( - lambda *plan_args: _decode_kernels.batch_decode_with_paged_kv_cache_plan( + lambda *plan_args: _kernels.batch_decode_with_paged_kv_cache_plan( use_logits_cap, head_dim, torch.empty(0, dtype=dtype_q), @@ -159,7 +159,7 @@ def get_batch_decode_module(*args): *plan_args, ) ) - run_func = _decode_kernels.batch_decode_with_paged_kv_cache_run + run_func = _kernels.batch_decode_with_paged_kv_cache_run else: mod = compile_batch_decode_module(*args) plan_func = mod.plan diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py index cc48182c..cb774514 100644 --- a/python/flashinfer/jit/env.py +++ b/python/flashinfer/jit/env.py @@ -31,10 +31,10 @@ def _get_workspace_dir_name() -> pathlib.Path: FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name() FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops" FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated" -_project_root = pathlib.Path(__file__).resolve().parent.parent.parent -FLASHINFER_INCLUDE_DIR = _project_root / "include" -FLASHINFER_CSRC_DIR = _project_root / "csrc" +_package_root = pathlib.Path(__file__).resolve().parents[1] +FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include" +FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc" CUTLASS_INCLUDE_DIRS = [ - _project_root / "3rdparty" / "cutlass" / "include", - _project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include", + _package_root / "data" / "cutlass" / "include", + _package_root / "data" / "cutlass" / "tools" / "util" / "include", ] diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index bfb8f48e..47f5d1a5 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -49,7 +49,7 @@ ) if has_prebuilt_ops: - from . import _prefill_kernels # type: ignore[attr-defined] + from . import _kernels # type: ignore[attr-defined] def compile_single_prefill_module( @@ -87,7 +87,7 @@ def get_single_prefill_module(*args): if has_prebuilt_ops and uri in prebuilt_ops_uri: # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later mask_mode = args[5] - run_func = lambda *run_args: _prefill_kernels.single_prefill_with_kv_cache( + run_func = lambda *run_args: _kernels.single_prefill_with_kv_cache( mask_mode, *run_args, ) @@ -159,21 +159,19 @@ def get_batch_prefill_module(*args): if has_prebuilt_ops and uri in prebuilt_ops_uri: # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later head_dim = args[4] - plan_func = ( - lambda *plan_args: _prefill_kernels.batch_prefill_with_kv_cache_plan( - head_dim, - *plan_args, - ) + plan_func = lambda *plan_args: _kernels.batch_prefill_with_kv_cache_plan( + head_dim, + *plan_args, ) mask_mode = args[6] ragged_run_func = ( - lambda *run_args: _prefill_kernels.batch_prefill_with_ragged_kv_cache_run( + lambda *run_args: _kernels.batch_prefill_with_ragged_kv_cache_run( mask_mode, *run_args, ) ) paged_run_func = ( - lambda *run_args: _prefill_kernels.batch_prefill_with_paged_kv_cache_run( + lambda *run_args: _kernels.batch_prefill_with_paged_kv_cache_run( mask_mode, *run_args, ) diff --git a/python/flashinfer/triton/kernels/__init__.py b/python/flashinfer/triton/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/include b/python/include deleted file mode 120000 index 3a1af68f..00000000 --- a/python/include +++ /dev/null @@ -1 +0,0 @@ -../include/ \ No newline at end of file diff --git a/python/jit_MANIFEST.in b/python/jit_MANIFEST.in new file mode 100644 index 00000000..33022575 --- /dev/null +++ b/python/jit_MANIFEST.in @@ -0,0 +1,17 @@ +# MANIFEST.in for JIT sdist + +global-exclude *.so + +prune */__pycache__ +prune csrc +prune csrc_aot +exclude flashinfer/jit/aot_config.py +exclude aot_setup.py +exclude mypy.ini +exclude pylintrc + +include flashinfer/data/version.txt +graft flashinfer/data/csrc +graft flashinfer/data/include +graft flashinfer/data/cutlass/include +graft flashinfer/data/cutlass/tools/util/include diff --git a/python/setup.py b/python/setup.py index 52166d51..fff0d5e2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,46 +14,87 @@ limitations under the License. """ -from typing import List, Tuple - -import pathlib +import contextlib import os +import pathlib +import shutil +import sys +from typing import Iterator +import warnings + import setuptools -root = pathlib.Path(__name__).parent +root = pathlib.Path(__file__).resolve().parents[1] +this_dir = pathlib.Path(__file__).parent def get_version(): version = os.getenv("FLASHINFER_BUILD_VERSION") if version is None: - with open(root / "version.txt") as f: + with open(this_dir / "flashinfer" / "data" / "version.txt") as f: version = f.read().strip() return version def generate_build_meta() -> None: version = get_version() - with open(root / "flashinfer/_build_meta.py", "w") as f: + with open(this_dir / "flashinfer" / "_build_meta.py", "w") as f: f.write(f"__version__ = {version!r}\n") def clear_aot_config(): # remove aot_config.py - aot_config_path = root / "flashinfer" / "jit" / "aot_config.py" + aot_config_path = this_dir / "flashinfer" / "jit" / "aot_config.py" if os.path.exists(aot_config_path): os.remove(aot_config_path) +@contextlib.contextmanager +def link_data_files() -> Iterator[None]: + this_dir = pathlib.Path(__file__).parent + data_dir = root / "python" / "flashinfer" / "data" + if data_dir.exists(): + shutil.rmtree(data_dir) + data_dir.mkdir(parents=True) + + def ln(src: str, dst: str, is_dir: bool = False) -> None: + (data_dir / dst).symlink_to(root / src, target_is_directory=is_dir) + + ln("3rdparty/cutlass", "cutlass", True) + ln("include", "include", True) + ln("python/csrc", "csrc", True) + ln("version.txt", "version.txt") + (this_dir / "MANIFEST.in").unlink(True) + (this_dir / "MANIFEST.in").symlink_to("jit_MANIFEST.in") + + yield + + if sys.argv[1] != "develop": + shutil.rmtree(data_dir) + (this_dir / "MANIFEST.in").unlink(True) + + if __name__ == "__main__": + link_data_files() generate_build_meta() clear_aot_config() - setuptools.setup( - name="flashinfer", - version=get_version(), - packages=setuptools.find_packages(), - author="FlashInfer team", - license="Apache License 2.0", - description="FlashInfer: Kernel Library for LLM Serving", - url="https://github.com/flashinfer-ai/flashinfer", - python_requires=">=3.8", - ) + + # Suppress warnings complaining that: + # Package 'flashinfer.data*' is absent from the `packages` configuration. + warnings.filterwarnings("ignore", r".*flashinfer\.data.*", UserWarning) + + with link_data_files(): + setuptools.setup( + name="flashinfer", + version=get_version(), + packages=setuptools.find_packages( + include=["flashinfer*"], + exclude=["flashinfer.data*"], + ), + include_package_data=True, + author="FlashInfer team", + license="Apache License 2.0", + description="FlashInfer: Kernel Library for LLM Serving", + url="https://github.com/flashinfer-ai/flashinfer", + python_requires=">=3.8", + ) diff --git a/python/version.txt b/python/version.txt deleted file mode 120000 index aa4e5bec..00000000 --- a/python/version.txt +++ /dev/null @@ -1 +0,0 @@ -../version.txt \ No newline at end of file diff --git a/scripts/run-ci-build-wheel.sh b/scripts/run-ci-build-wheel.sh index c50118a1..5d445982 100644 --- a/scripts/run-ci-build-wheel.sh +++ b/scripts/run-ci-build-wheel.sh @@ -42,7 +42,8 @@ echo "::endgroup::" echo "::group::Build wheel for FlashInfer" cd "$PROJECT_ROOT/python" -FLASHINFER_BUILD_VERSION="${FLASHINFER_BUILD_VERSION}+cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH_VERSION}" python -m build --no-isolation +FLASHINFER_BUILD_VERSION="${FLASHINFER_BUILD_VERSION}+cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH_VERSION}" python aot_setup.py bdist_wheel rm -f dist/*.tar.gz python -m build --no-isolation --sdist +ls -la dist/ echo "::endgroup::"