@@ -1971,8 +1971,7 @@ def selected_mask_types(kspec):
         sliding_or_chunked_causal_mask = '0'
         custom_mask = '0'
     elif (kspec.head_size, kspec.head_size_v) == (192, 128):
-        # MLA context phase only needs causal mask now
-        padding_mask = '0'
+        # MLA context phase only needs causal mask and padding mask (for chunked prefill) now
         sliding_or_chunked_causal_mask = '0'
         custom_mask = '0'
     elif (kspec.head_size, kspec.head_size_v) == (576, 512):
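
For reference, here is a minimal standalone sketch of the mask selection this hunk changes. It is not the generator's actual selected_mask_types (which takes a kspec and sets more flags); it only mirrors the (192, 128) MLA context rule, where the padding mask now stays enabled for chunked prefill and only the sliding/chunked-causal and custom masks are forced off. The '1' defaults are an assumption.

def mla_context_mask_flags(head_size, head_size_v):
    # Assumed defaults: all mask types enabled ('1') until a head-size rule disables them.
    padding_mask = causal_mask = '1'
    sliding_or_chunked_causal_mask = custom_mask = '1'
    if (head_size, head_size_v) == (192, 128):
        # MLA context phase: keep the causal mask and the padding mask (chunked prefill),
        # drop the sliding/chunked-causal and custom masks.
        sliding_or_chunked_causal_mask = '0'
        custom_mask = '0'
    return padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask

assert mla_context_mask_flags(192, 128) == ('1', '1', '0', '0')
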
@@ -2311,8 +2310,7 @@ def gen_call(kspec, lname):
     # whether support alibi or not.
     if kspec.warp_specialization:
         il_check += '&& params.has_alibi ' if kspec.alibi else '&& !params.has_alibi '
-        if kspec.input_layout.value == InputLayout.CONTIGUOUS_Q_KV:
-            il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+        il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
     # use enable_attn_logit_softcapping or not.
     il_check += '&& enable_attn_logit_softcapping ' if kspec.enable_attn_logit_softcapping else '&& !enable_attn_logit_softcapping '
    # check sage block sizes
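
Below is a minimal sketch of what this hunk does to the generated runtime guard: after the change, the softmax_stats_ptr check is appended for every warp-specialized kernel, not only for the CONTIGUOUS_Q_KV layout. The SimpleNamespace stands in for the real kspec object, and starting il_check from an empty string is a simplification of the surrounding gen_call code.

from types import SimpleNamespace

def build_il_check(kspec):
    il_check = ''
    if kspec.warp_specialization:
        il_check += '&& params.has_alibi ' if kspec.alibi else '&& !params.has_alibi '
        # After this change, the softmax-stats check is emitted for every input layout.
        il_check += ('&& params.softmax_stats_ptr != nullptr '
                     if kspec.return_softmax_stats else
                     '&& params.softmax_stats_ptr == nullptr ')
    il_check += ('&& enable_attn_logit_softcapping '
                 if kspec.enable_attn_logit_softcapping else
                 '&& !enable_attn_logit_softcapping ')
    return il_check

kspec = SimpleNamespace(warp_specialization=True, alibi=False,
                        return_softmax_stats=True,
                        enable_attn_logit_softcapping=False)
print(build_il_check(kspec))
# && !params.has_alibi && params.softmax_stats_ptr != nullptr && !enable_attn_logit_softcapping
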
@@ -3653,104 +3651,110 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
         # alibi and enable_attn_logit_softcapping shouldn't be used together.
         if alibi and enable_attn_logit_softcapping:
             continue
-        if input_layout != InputLayout.CONTIGUOUS_Q_KV and return_softmax:
-            continue
-        # only specify
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[32, 40, 48, 64],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+        # for normal attention, we only need contiguous kv as input layout when returning softmax.
+        skip_combination = return_softmax and (input_layout
+                                               != InputLayout.CONTIGUOUS_Q_KV)
+        # for context mla, we need paged kv or separate qkv as input layout when returning softmax.
+        skip_mla_combination = return_softmax and (
+            input_layout != InputLayout.Q_PAGED_KV
+            and input_layout != InputLayout.SEPARATE_Q_K_V)
+        if not skip_combination:
+            # only specify
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[32, 40, 48, 64],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))

-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[72, 80, 96, 104, 128],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=128,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[72, 80, 96, 104, 128],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=128,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))

-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[160, 192, 256],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=64,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[160, 192, 256],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=64,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))
         '''
         smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
                      + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
@@ -3762,38 +3766,39 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
         Then for fp16/bf16 context MLA, d remains 192 (192 * 2 = 128 * 3), and dv remains 128,
         if kv_step = 128, then smem_size = 208 KB, smem is fully utilized.
         '''
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=192,
-                head_size_v=128,
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=128,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+        if not skip_mla_combination:
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=192,
+                    head_size_v=128,
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=128,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))


 # Note this will be used in TRT-LLM.
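
As a sanity check on the two changes above, the standalone sketch below first reproduces the skip_combination / skip_mla_combination filtering to show which input layouts still get a softmax-stats variant, and then evaluates the shared-memory formula quoted in the docstring for the fp16/bf16 context-MLA spec. The InputLayout class here is a stand-in that only reproduces the members named in this diff, and NUM_COMPUTE_GROUPS = 2 is assumed from the surrounding comment rather than shown in the hunk.

from enum import Enum, auto

class InputLayout(Enum):
    # Stand-in for the generator's InputLayout; only the members used here.
    PACKED_QKV = auto()
    CONTIGUOUS_Q_KV = auto()
    Q_PAGED_KV = auto()
    SEPARATE_Q_K_V = auto()

def emits_softmax_stats_variant(input_layout, return_softmax):
    """Mirror of the skip logic: (normal kernel emitted?, MLA kernel emitted?)."""
    skip_combination = return_softmax and (input_layout
                                           != InputLayout.CONTIGUOUS_Q_KV)
    skip_mla_combination = return_softmax and (
        input_layout != InputLayout.Q_PAGED_KV
        and input_layout != InputLayout.SEPARATE_Q_K_V)
    return (not skip_combination, not skip_mla_combination)

# When returning softmax stats, normal attention only keeps CONTIGUOUS_Q_KV,
# while context MLA keeps Q_PAGED_KV and SEPARATE_Q_K_V.
assert emits_softmax_stats_variant(InputLayout.CONTIGUOUS_Q_KV, True) == (True, False)
assert emits_softmax_stats_variant(InputLayout.Q_PAGED_KV, True) == (False, True)
assert emits_softmax_stats_variant(InputLayout.PACKED_QKV, True) == (False, False)
# Without softmax stats, every layout keeps both kernel families.
assert emits_softmax_stats_variant(InputLayout.PACKED_QKV, False) == (True, True)

# Shared-memory budget of the context-MLA spec, per the docstring formula:
# smem = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
#         + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
NUM_COMPUTE_GROUPS = 2            # assumed: two compute warpgroups
ele_size = 2                      # bytes per fp16/bf16 element
q_step, q_buffers = 64, 1         # loop_step, q_tile_buffers
kv_step, kv_buffers = 128, 2      # kv_loop_step, kv_tile_buffers
d, dv = 192, 128                  # head_size, head_size_v
smem = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
        + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
assert smem == 208 * 1024         # 208 KB: Hopper smem fully utilized, as the docstring states
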