perf: speedup jit compilation of prefill attention kernels (#632)
Follow-up to #628: this PR splits the prefill attention JIT templates so that
different mask modes are compiled in separate files.

After this PR, the JIT compilation time of a prefill kernel for a given
configuration (shape, dtype, etc.) can be reduced to about 10 seconds.
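
For illustration, a minimal sketch (not part of this PR; the file names and the compile helper below are hypothetical, and the real build is driven by torch.utils.cpp_extension / ninja) of why one generated .cu file per mask mode helps: the per-mask-mode sources are smaller translation units and can be compiled in parallel, instead of instantiating every mask mode in a single file.

from concurrent.futures import ThreadPoolExecutor
import subprocess

# One generated source per mask mode (0 = kNone, 1 = kCausal, 2 = kCustom),
# mirroring the per-mask-mode suffixes introduced in this PR; the file names
# here are hypothetical.
sources = [f"batch_prefill_ragged_kernel_mask_{m}.cu" for m in range(3)]

def compile_one(src: str) -> str:
    # Hypothetical direct nvcc invocation; the actual build goes through
    # torch.utils.cpp_extension / ninja, which also compiles sources concurrently.
    obj = src.replace(".cu", ".o")
    subprocess.check_call(["nvcc", "-O3", "-c", src, "-o", obj])
    return obj

# Smaller translation units compile faster, and independent files can be
# built in parallel rather than as one monolithic unit.
with ThreadPoolExecutor(max_workers=3) as pool:
    objects = list(pool.map(compile_one, sources))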
yzh119 authored Nov 24, 2024
1 parent 5bf36ce commit a059586
Showing 2 changed files with 224 additions and 18 deletions.
129 changes: 123 additions & 6 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -14,15 +14,102 @@
limitations under the License.
"""

import itertools

batch_prefill_suffix = [
"_plan.cu",
*[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
"_ragged_run.cu",
*[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
"_paged_run.cu",
"_pybind.cc",
]


def ragged_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
namespace flashinfer {
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
constexpr bool use_custom_mask = """
+ mask_mode
+ r""" == MaskMode::kCustom;
using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
template
cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", RaggedAttentionVariant>(
typename RaggedAttentionVariant::ParamsT params,
typename RaggedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
template
cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", RaggedAttentionVariant>(
typename RaggedAttentionVariant::ParamsT params,
typename RaggedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
template
cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", RaggedAttentionVariant>(
typename RaggedAttentionVariant::ParamsT params,
typename RaggedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
}
"""
)


def paged_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
namespace flashinfer {
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
constexpr bool use_custom_mask = """
+ mask_mode
+ r""" == MaskMode::kCustom;
using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
template
cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", PagedAttentionVariant>(
typename PagedAttentionVariant::ParamsT params,
typename PagedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
template
cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", PagedAttentionVariant>(
typename PagedAttentionVariant::ParamsT params,
typename PagedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
template
cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", PagedAttentionVariant>(
typename PagedAttentionVariant::ParamsT params,
typename PagedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
}
"""
)


batch_prefill_templ = [
r"""#include <flashinfer/attention/scheduler.cuh>
#include "pytorch_extension_utils.h"
@@ -60,10 +147,15 @@
return plan_info.ToVector();
}
""",
*[
ragged_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""
#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include "pytorch_extension_utils.h"
@@ -73,6 +165,16 @@
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
namespace flashinfer {
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
};
void BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code,
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -153,7 +255,7 @@
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithRaggedKVCacheDispatched<
status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
params, tmp_v, tmp_s, stream);
});
@@ -162,9 +264,14 @@
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
}
""",
*[
paged_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include "pytorch_extension_utils.h"
@@ -174,6 +281,16 @@
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
namespace flashinfer {
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
};
void BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code,
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -274,7 +391,7 @@
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithPagedKVCacheDispatched<
status = flashinfer::BatchPrefillWithPagedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
params, tmp_v, tmp_s, stream);
});
113 changes: 101 additions & 12 deletions python/flashinfer/jit/single_prefill_templ.py
@@ -15,19 +15,12 @@
"""

single_prefill_suffix = [
*[f"_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
".cu",
"_pybind.cc",
]

customizable_single_prefill_templ = [
r"""
#include <optional>
#include <flashinfer/attention/prefill.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;
customizable_struct_templ = r"""
struct SinglePrefillParams {
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
@@ -82,10 +75,63 @@
return kv_len;
}
};
"""


def customizable_single_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
using namespace flashinfer;
"""
+ customizable_struct_templ
+ r"""{{ variant_decl }}
using ParamsT = SinglePrefillParams;
using AttentionVariant = {{ variant_name }}<ParamsT>;
namespace flashinfer {
template
cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, """
f"{mask_mode}"
r""", AttentionVariant>(
typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);
};
"""
)


customizable_single_prefill_templ = [
*[
customizable_single_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""
#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/mask.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;
"""
+ customizable_struct_templ
+ r"""
{{ variant_decl }}
namespace flashinfer {
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);
}
at::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor tmp, at::Tensor o, unsigned int layout, int32_t window_left,
@@ -155,10 +201,43 @@
""",
]


def single_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
namespace flashinfer {
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
constexpr bool use_custom_mask = """
f"{mask_mode}"
r"""== MaskMode::kCustom;
using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
template
cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
f"{mask_mode}"
r""", AttentionVariant>(
typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);
}
"""
)


single_prefill_templ = [
r"""
#include <optional>
#include <flashinfer/attention/prefill.cuh>
*[
single_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include "pytorch_extension_utils.h"
@@ -168,6 +247,16 @@
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
namespace flashinfer {
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);
}
void single_prefill_with_kv_cache(
unsigned int mask_mode_code,
at::Tensor q, at::Tensor k, at::Tensor v, std::optional<at::Tensor> maybe_packed_custom_mask,
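
For context, a minimal sketch of how the suffix and template lists above pair up, assuming the templates are rendered with jinja2 (the render_templates helper, the file-name prefix, and the output directory are hypothetical; the actual wiring lives in flashinfer.jit): each suffix names the file that receives the corresponding rendered template, so every mask mode lands in its own translation unit.

from pathlib import Path
import jinja2

def render_templates(suffixes, templates, prefix, out_dir, **params):
    # Each suffix pairs one-to-one with a template string; rendering writes one
    # source file per entry, so the three mask modes land in three separate files.
    assert len(suffixes) == len(templates)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for suffix, source in zip(suffixes, templates):
        (out / f"{prefix}{suffix}").write_text(jinja2.Template(source).render(**params))

# Example call with hypothetical values for the Jinja parameters the templates use:
# render_templates(batch_prefill_suffix, batch_prefill_templ, "batch_prefill_f16",
#                  "./generated", dtype_q="half", dtype_kv="half", dtype_o="half",
#                  dtype_idx="int32_t", head_dim=128,
#                  pos_encoding_mode="PosEncodingMode::kNone",
#                  use_sliding_window="false", use_logits_soft_cap="false",
#                  use_fp16_qk_reduction="false")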
