[Refactor] Accelerate PyTorch Extension Compilation (#98)
This PR accelerates the compilation of PyTorch extensions via explicit template
instantiation. The template instances are generated automatically in `setup.py`
and split into separate translation units; this trick is adapted from the
Punica project.

Co-authored-by: Lequn Chen <[email protected]>
yzh119 and abcdabcd987 authored Jan 31, 2024
1 parent 633b537 commit 6031092
Showing 31 changed files with 1,060 additions and 728 deletions.
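
For readers unfamiliar with the trick: heavily templated CUDA kernels are normally instantiated implicitly, over and over, in every translation unit that uses them, all inside a single nvcc invocation. Pre-generating one small `.cu` file per template-parameter combination turns the build into many independent nvcc jobs that run in parallel. A minimal, self-contained sketch of the pattern (the names below are illustrative, not FlashInfer's real API):

```cuda
#include <cuda_runtime.h>

// A templated kernel plus launcher, as they would live in a header.
template <int HEAD_DIM, typename T>
__global__ void ScaleKernel(T* x, T alpha) {
  x[blockIdx.x * HEAD_DIM + threadIdx.x] *= alpha;
}

template <int HEAD_DIM, typename T>
cudaError_t Scale(T* x, T alpha, int num_rows, cudaStream_t stream = nullptr) {
  ScaleKernel<HEAD_DIM, T><<<num_rows, HEAD_DIM, 0, stream>>>(x, alpha);
  return cudaGetLastError();
}

// A generated file such as csrc/generated/scale_hd128_f32.cu (hypothetical
// name) would contain only the header include plus one explicit
// instantiation, pinning that combination to its own translation unit:
template cudaError_t Scale<128, float>(float*, float, int, cudaStream_t);
```

The `python/csrc/generated/` entry added to `.gitignore` below is where these generated units land.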
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
# Generated cu files
python/csrc/generated/
python/flashinfer/_build_meta.py

microbenchmark/
.vscode/

1 change: 1 addition & 0 deletions include/flashinfer.cuh
@@ -23,5 +23,6 @@
#include "flashinfer/page.cuh"
#include "flashinfer/prefill.cuh"
#include "flashinfer/rope.cuh"
#include "flashinfer/wrapper.cuh"

#endif // FLASHINFER_CUH_
8 changes: 4 additions & 4 deletions include/flashinfer/cascade.cuh
@@ -360,7 +360,7 @@ template <typename DTypeIn, typename DTypeOut>
cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DTypeOut* v_merged,
float* s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
uint32_t bdx = HEAD_DIM / vec_size;
uint32_t bdy = num_heads;
@@ -391,7 +391,7 @@ template <typename DType>
cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
uint32_t num_heads, uint32_t head_dim,
cudaStream_t stream = nullptr) {
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
uint32_t bdx = HEAD_DIM / vec_size;
uint32_t bdy = num_heads;
@@ -424,7 +424,7 @@ template <typename DTypeIn, typename DTypeOut>
cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merged,
uint32_t num_index_sets, uint32_t seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
if (num_index_sets >= seq_len) {
@@ -457,7 +457,7 @@ template <typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged,
float* s_merged, uint32_t seq_len, uint32_t num_heads,
uint32_t head_dim, cudaStream_t stream = nullptr) {
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
constexpr uint32_t num_threads = 128;
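
The change running through this file (and the rest of the diff) is a rename of the `SWITCH_*` dispatch macros to `DISPATCH_*`. For context, a sketch of what such a macro does, with an illustrative (not necessarily exact) set of supported sizes: it switches on the runtime `head_dim` and rebinds it as a compile-time constant that the body can use as a template parameter.

```cpp
// Sketch of a DISPATCH_HEAD_DIM-style macro; the real case list lives in
// FlashInfer's utils header. Expansion sites need <sstream> and
// <stdexcept> in scope, as the surrounding code already does.
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)      \
  switch (head_dim) {                                   \
    case 64: {                                          \
      constexpr uint32_t HEAD_DIM = 64;                 \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    case 128: {                                         \
      constexpr uint32_t HEAD_DIM = 128;                \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    case 256: {                                         \
      constexpr uint32_t HEAD_DIM = 256;                \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    default: {                                          \
      std::ostringstream err_msg;                       \
      err_msg << "Unsupported head_dim: " << head_dim;  \
      throw std::invalid_argument(err_msg.str());       \
    }                                                   \
  }
```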
36 changes: 18 additions & 18 deletions include/flashinfer/decode.cuh
@@ -735,12 +735,12 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
if (seq_len <= 256U) {
tmp_size = 0;
} else {
SWITCH_GQA_GROUP_SIZE(
DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM,
{SWITCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(kv_layout, KV_LAYOUT, {
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -817,12 +817,12 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
throw std::invalid_argument(err_msg.str());
}

SWITCH_GQA_GROUP_SIZE(
DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM,
{SWITCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(kv_layout, KV_LAYOUT, {
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -1055,10 +1055,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
const uint32_t num_kv_heads, const uint32_t head_dim, const uint32_t page_size,
const RotaryMode rotary_mode = RotaryMode::kNone, cudaStream_t stream = nullptr) {
SWITCH_GQA_GROUP_SIZE(
DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -1226,10 +1226,10 @@ cudaError_t BatchDecodeWithPagedKVCache(
throw std::invalid_argument(err_msg.str());
}

SWITCH_GQA_GROUP_SIZE(
DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
kv_layout, ROTARY_MODE, DTypeIn, DTypeOut,
IdType>(
@@ -1295,12 +1295,12 @@ cudaError_t BatchDecodeWithPaddedKVCache(
throw std::invalid_argument(err_msg.str());
}

SWITCH_GQA_GROUP_SIZE(
DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM,
{SWITCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(kv_layout, KV_LAYOUT, {
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
return BatchDecodeWithPaddedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
ROTARY_MODE, DTypeIn, DTypeOut>(
q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, rope_scale,
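
Worth noting in the hunks above: each public entry point funnels into a fully templated `...Dispatched` function (e.g. `BatchDecodeWithPagedKVCacheDispatched`). That factoring is what enables the build trick, since the generated `.cu` files can explicitly instantiate the `Dispatched` variants for every combination the macros can reach. A back-of-the-envelope count of why that matters (the supported value sets here are illustrative):

```cpp
// Hypothetical instantiation count for one decode entry point:
//   GQA group sizes  {1, 2, 4, 8}      -> 4
//   head dims        {64, 128, 256}    -> 3
//   rotary modes     {kNone, kLlama}   -> 2
//   KV layouts       {kNHD, kHND}      -> 2
//   dtype pairs (DTypeIn, DTypeOut)    -> d
// Total: 4 * 3 * 2 * 2 * d = 48 * d heavyweight kernel instantiations.
// Splitting them across generated translation units lets nvcc compile
// them in parallel instead of sequentially inside one file.
```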
189 changes: 0 additions & 189 deletions include/flashinfer/handler.cuh
@@ -22,7 +22,6 @@
#include <vector>

#include "decode.cuh"
#include "prefill.cuh"
#include "rope.cuh"
#include "utils.cuh"

@@ -243,193 +242,5 @@ class BatchPrefillHandler {
cudaStream_t stream_;
};

/*!
 * \brief Wrapper of the BatchDecodeWithPagedKVCache function that caches the temporary buffer
* for cooperative kernels.
* \tparam page_storage Whether to store indices or pointers of each active page
* \tparam kv_layout The layout of last 3 dimensions in KV-Cache
* \tparam DTypeIn The data type of input tensor.
* \tparam DTypeOut The data type of output tensor.
* \tparam IdType The data type of index tensor.
* \param handler The handler for the batch decode forward request.
* \param q The input tensor.
* \param paged_kv The paged key-value tensor.
* \param o The output tensor.
* \param lse The logsumexp values.
 * \param num_qo_heads The number of query/output heads.
* \param rotary_mode The rotary mode.
* \param rope_scale The scale of rope.
* \param rope_theta The theta of rope.
* \param stream The CUDA stream.
 * \note This wrapper function should only be called after calling the BeginForward() function of
 * the BatchDecodeHandler.
*/
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWrapper(
BatchDecodeHandler* handler, DTypeIn* q,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
kv_partition_info_t<IdType> kv_partition_info;
DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
if (handler->IsForwardStarted()) {
if (tmp != nullptr) {
// create auxiliary information for cooperative kernels
new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
}
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
"BatchDecodeWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
}
return BatchDecodeWithPagedKVCache<page_storage, kv_layout, DTypeIn, DTypeOut, IdType>(
q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
rope_theta, stream);
}

template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, float rope_scale = 1.f, float rope_theta = 1e4,
cudaStream_t stream = nullptr) {
float* tmp = nullptr;
IdType* request_indices = nullptr;
IdType* tile_indices = nullptr;
uint32_t num_frags_x = 0U;
uint32_t num_qo_tiles = 0U;
if (handler->IsForwardStarted()) {
request_indices = handler->GetRequestIndices<IdType>();
tile_indices = handler->GetTileIndices<IdType>();
num_frags_x = handler->GetNumFragsX();
num_qo_tiles = handler->GetNumQOTiles();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
"BatchPrefillWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
}

SWITCH_NUM_FRAGS_X(
num_frags_x, NUM_FRAGS_X, {SWITCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, {
if constexpr (PAGE_SIZE == 0) {
return BatchPrefillWithPagedKVCacheFallbackDispatched<
page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
rope_scale, rope_theta, stream);
} else {
return BatchPrefillWithPagedKVCacheDispatched<
page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
rope_scale, rope_theta, stream);
}
})});
return cudaSuccess;
}

template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapper(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
cudaStream_t stream = nullptr) {
const uint32_t num_kv_heads = paged_kv.num_heads;
const uint32_t head_dim = paged_kv.head_dim;
SWITCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
head_dim, HEAD_DIM,
{SWITCH_CAUSAL(causal, CAUSAL,
{SWITCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE,
{SWITCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithPagedKVCacheWrapperDispatched<
page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, paged_kv, o, lse, num_qo_heads,
rope_scale, rope_theta, stream);
})})})})});
return cudaSuccess;
}

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMode ROTARY_MODE,
bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
const uint32_t num_kv_heads, const float rope_scale = 1.f, const float rope_theta = 1e4,
cudaStream_t stream = nullptr) {
float* tmp = nullptr;
IdType* request_indices = nullptr;
IdType* tile_indices = nullptr;
uint32_t num_frags_x = 0U;
uint32_t num_qo_tiles = 0U;
if (handler->IsForwardStarted()) {
request_indices = handler->GetRequestIndices<IdType>();
tile_indices = handler->GetTileIndices<IdType>();
num_frags_x = handler->GetNumFragsX();
num_qo_tiles = handler->GetNumQOTiles();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
"BatchPrefillWithRaggedKVWrapperCache()";
throw std::runtime_error(err_msg.str());
}

SWITCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, {
return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, o, tmp, lse, batch_size,
num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream);
});
return cudaSuccess;
}

template <typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim,
bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
bool allow_fp16_qk_reduction = false, const float rope_scale = 1.f,
const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD;
SWITCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
head_dim, HEAD_DIM,
{SWITCH_CAUSAL(causal, CAUSAL,
{SWITCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE,
{SWITCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithRaggedKVCacheWrapperDispatched<
GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
num_kv_heads, rope_scale, rope_theta, stream);
})})})})});
return cudaSuccess;
}

} // namespace flashinfer
#endif // FLASHINFER_HANDLER_CUH_
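
The wrapper functions deleted above are not gone; judging by the `flashinfer/wrapper.cuh` include added to `include/flashinfer.cuh` earlier in this diff, they move there, which also lets `handler.cuh` drop its `prefill.cuh` include. Their calling convention is the one the doc comment describes, sketched below (argument lists are elided, and `EndForward()` is assumed from the handler lifecycle rather than shown in this diff):

```cpp
// Calling convention sketch, not exact signatures:
//
//   BatchDecodeHandler handler;
//   handler.BeginForward(/* kv indptr, batch size, head counts, ... */);
//                                    // plan: partition long KV sequences,
//                                    // allocate the cooperative tmp buffer
//   BatchDecodeWithPagedKVCacheWrapper(&handler, q, paged_kv, o, lse,
//                                      num_qo_heads, rotary_mode);
//                                    // run: may repeat while the plan holds
//   handler.EndForward();            // assumed teardown of cached state
```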
6 changes: 3 additions & 3 deletions include/flashinfer/page.cuh
@@ -415,7 +415,7 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t<page_storage, layout, DType, IdT
uint32_t head_dim = paged_kv.head_dim;
uint32_t batch_size = paged_kv.batch_size;
uint32_t num_heads = paged_kv.num_heads;
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
uint32_t bdx = HEAD_DIM / vec_size;
uint32_t bdy = num_heads;
@@ -449,7 +449,7 @@ cudaError_t AppendPagedKVCache(paged_kv_t<page_storage, layout, DType, IdType> p
uint32_t head_dim = paged_kv.head_dim;
uint32_t batch_size = paged_kv.batch_size;
uint32_t num_heads = paged_kv.num_heads;
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
uint32_t bdx = HEAD_DIM / vec_size;
uint32_t bdy = num_heads;
@@ -530,7 +530,7 @@ cudaError_t PagedKVCacheToRaggedTensor(paged_kv_t<page_storage, layout, DType, I
const uint32_t num_heads = paged_kv.num_heads;
const uint32_t page_size = paged_kv.page_size;

SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
uint32_t bdx = HEAD_DIM / vec_size;
uint32_t bdy = num_heads;
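
One idiom recurs in the context lines of every file above: `vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32)`. The `16` requests a vector access of at least 16 bytes (128 bits), while the `HEAD_DIM / 32` floor keeps `bdx = HEAD_DIM / vec_size` within 32 threads. Two worked examples of the arithmetic:

```cpp
// DType = half (2 bytes), HEAD_DIM = 128:
//   vec_size = max(16 / 2, 128 / 32) = max(8, 4) = 8  ->  bdx = 128 / 8 = 16
// DType = float (4 bytes), HEAD_DIM = 256:
//   vec_size = max(16 / 4, 256 / 32) = max(4, 8) = 8  ->  bdx = 256 / 8 = 32
```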