From a0595866db384b4a782c1ec70df72251b17de287 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sun, 24 Nov 2024 02:38:28 -0800
Subject: [PATCH] perf: speedup jit compilation of prefill attention kernels (#632)

A follow-up to https://github.com/flashinfer-ai/flashinfer/pull/628, this PR splits the prefill attention JIT templates so that different mask modes are compiled in separate files. After this PR, the JIT compilation time for a prefill kernel with a given configuration (shape, dtype, etc.) can be reduced to about 10 seconds.

---
 python/flashinfer/jit/batch_prefill_templ.py  | 129 +++++++++++++++++-
 python/flashinfer/jit/single_prefill_templ.py | 113 +++++++++++++--
 2 files changed, 224 insertions(+), 18 deletions(-)

diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py
index b6be762c..c07cd73b 100644
--- a/python/flashinfer/jit/batch_prefill_templ.py
+++ b/python/flashinfer/jit/batch_prefill_templ.py
@@ -14,15 +14,102 @@ limitations under the License. """ -import itertools - batch_prefill_suffix = [ "_plan.cu", + *[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], "_ragged_run.cu", + *[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], "_paged_run.cu", "_pybind.cc", ] + +def ragged_prefill_inst_templ(mask_mode: str) -> str: + return ( + r"""#include +#include +#include + +namespace flashinfer { + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +constexpr bool use_custom_mask = """ + + mask_mode + + r""" == MaskMode::kCustom; +using RaggedAttentionVariant = ComposedAttention; + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + typename RaggedAttentionVariant::ParamsT params, + typename RaggedAttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + typename RaggedAttentionVariant::ParamsT params, + typename RaggedAttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + typename RaggedAttentionVariant::ParamsT params, + typename RaggedAttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); +} +""" + ) + + +def paged_prefill_inst_templ(mask_mode: str) -> str: + return ( + r"""#include +#include +#include + +namespace flashinfer { + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +constexpr bool use_custom_mask = """ + + mask_mode + + r""" == MaskMode::kCustom; +using PagedAttentionVariant = ComposedAttention; + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + typename PagedAttentionVariant::ParamsT params, + typename PagedAttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + typename PagedAttentionVariant::ParamsT params, + typename PagedAttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + typename PagedAttentionVariant::ParamsT params, + typename PagedAttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); +} +""" + ) + + batch_prefill_templ = [ r"""#include #include "pytorch_extension_utils.h" @@ -60,10 +147,15 @@ return 
plan_info.ToVector(); } """, + *[ + ragged_prefill_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], r""" #include +#include #include -#include +#include #include #include #include "pytorch_extension_utils.h" @@ -73,6 +165,16 @@ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +namespace flashinfer { + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +}; + void BatchPrefillWithRaggedKVCacheRun( unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, @@ -153,7 +255,7 @@ constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; using RaggedAttentionVariant = ComposedAttention; DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - status = BatchPrefillWithRaggedKVCacheDispatched< + status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched< CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>( params, tmp_v, tmp_s, stream); }); @@ -162,9 +264,14 @@ TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status)); } """, + *[ + paged_prefill_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], r"""#include +#include #include -#include +#include #include #include #include "pytorch_extension_utils.h" @@ -174,6 +281,16 @@ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +namespace flashinfer { + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +}; + void BatchPrefillWithPagedKVCacheRun( unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, @@ -274,7 +391,7 @@ constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; using PagedAttentionVariant = ComposedAttention; DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - status = BatchPrefillWithPagedKVCacheDispatched< + status = flashinfer::BatchPrefillWithPagedKVCacheDispatched< CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>( params, tmp_v, tmp_s, stream); }); diff --git a/python/flashinfer/jit/single_prefill_templ.py b/python/flashinfer/jit/single_prefill_templ.py index 83e5ebcb..a2b7104a 100644 --- a/python/flashinfer/jit/single_prefill_templ.py +++ b/python/flashinfer/jit/single_prefill_templ.py @@ -15,19 +15,12 @@ """ single_prefill_suffix = [ + *[f"_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], ".cu", "_pybind.cc", ] -customizable_single_prefill_templ = [ - r""" -#include -#include -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - - +customizable_struct_templ = r""" struct SinglePrefillParams { using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; @@ -82,10 +75,63 @@ return kv_len; } }; +""" + + +def customizable_single_prefill_inst_templ(mask_mode: str) -> str: + return ( + r"""#include + +using namespace flashinfer; +""" + + 
customizable_struct_templ + + r"""{{ variant_decl }} +using ParamsT = SinglePrefillParams; +using AttentionVariant = {{ variant_name }}; + +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, """ + f"{mask_mode}" + r""", AttentionVariant>( + typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +}; +""" + ) + + +customizable_single_prefill_templ = [ + *[ + customizable_single_prefill_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], + r""" +#include +#include +#include +#include "pytorch_extension_utils.h" +using namespace flashinfer; +""" + + customizable_struct_templ + + r""" {{ variant_decl }} +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +} + at::Tensor single_prefill_with_kv_cache( unsigned int mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, unsigned int layout, int32_t window_left, @@ -155,10 +201,43 @@ """, ] + +def single_prefill_inst_templ(mask_mode: str) -> str: + return ( + r"""#include +#include +#include + +namespace flashinfer { + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; +constexpr bool use_custom_mask = """ + f"{mask_mode}" + r"""== MaskMode::kCustom; +using AttentionVariant = ComposedAttention; + +template +cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """ + f"{mask_mode}" + r""", AttentionVariant>( + typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +} +""" + ) + + single_prefill_templ = [ - r""" -#include -#include + *[ + single_prefill_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], + r"""#include +#include +#include #include #include #include "pytorch_extension_utils.h" @@ -168,6 +247,16 @@ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +} + void single_prefill_with_kv_cache( unsigned int mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v, std::optional maybe_packed_custom_mask,