
Commit 4200fa4

[None][feat] Add support for Hopper MLA chunked prefill (#6655)
Signed-off-by: Mingyang Jiang <[email protected]>
1 parent 868c5d1 · commit 4200fa4

File tree

15 files changed: +253 additions, -195 deletions


cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 17 additions & 0 deletions
@@ -183,6 +183,23 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
         shell=True,
         check=True)
 
+    # For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v).
+    if dtype == "-bf16" and input_layout in [
+            "-paged-kv", "-separate-q-k-v"
+    ]:
+        # padding mask
+        subprocess.run(
+            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
+                {epsilon} {input_layout} -save-softmax",
+            shell=True,
+            check=True)
+        # causal mask
+        subprocess.run(
+            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
+                -causal-mask {epsilon} {input_layout} -save-softmax",
+            shell=True,
+            check=True)
+
 
 @pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                          ids=["bf16", "e4m3", "e4m3-bf16"])

cpp/kernels/fmha_v2/setup.py

Lines changed: 137 additions & 132 deletions
@@ -1971,8 +1971,7 @@ def selected_mask_types(kspec):
         sliding_or_chunked_causal_mask = '0'
         custom_mask = '0'
     elif (kspec.head_size, kspec.head_size_v) == (192, 128):
-        # MLA context phase only needs causal mask now
-        padding_mask = '0'
+        # MLA context phase only needs causal mask and padding mask (for chunked prefill) now
         sliding_or_chunked_causal_mask = '0'
         custom_mask = '0'
     elif (kspec.head_size, kspec.head_size_v) == (576, 512):
@@ -2311,8 +2310,7 @@ def gen_call(kspec, lname):
     # whether support alibi or not.
     if kspec.warp_specialization:
         il_check += '&& params.has_alibi ' if kspec.alibi else '&& !params.has_alibi '
-        if kspec.input_layout.value == InputLayout.CONTIGUOUS_Q_KV:
-            il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+        il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
     # use enable_attn_logit_softcapping or not.
     il_check += '&& enable_attn_logit_softcapping ' if kspec.enable_attn_logit_softcapping else '&& !enable_attn_logit_softcapping '
     # check sage block sizes
@@ -3653,104 +3651,110 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
         # alibi and enable_attn_logit_softcapping shouldn't be used together.
         if alibi and enable_attn_logit_softcapping:
             continue
-        if input_layout != InputLayout.CONTIGUOUS_Q_KV and return_softmax:
-            continue
-        # only specify
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[32, 40, 48, 64],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+        # for normal attention, we only need contiguous kv as input layout when returning softmax.
+        skip_combination = return_softmax and (input_layout
+                                               != InputLayout.CONTIGUOUS_Q_KV)
+        # for context mla, we need paged kv or separate qkv as input layout when returning softmax.
+        skip_mla_combination = return_softmax and (
+            input_layout != InputLayout.Q_PAGED_KV
+            and input_layout != InputLayout.SEPARATE_Q_K_V)
+        if not skip_combination:
+            # only specify
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[32, 40, 48, 64],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))
 
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[72, 80, 96, 104, 128],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=128,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[72, 80, 96, 104, 128],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=128,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))
 
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[160, 192, 256],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=64,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[160, 192, 256],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=64,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))
         '''
         smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
                      + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
@@ -3762,38 +3766,39 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
         Then for fp16/bf16 context MLA, d remains 192 (192 * 2 = 128 * 3), and dv remains 128,
         if kv_step = 128, then smem_size = 208 KB, smem is fully utilized.
         '''
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=192,
-                head_size_v=128,
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=128,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=return_softmax,
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout))
+        if not skip_mla_combination:
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=192,
+                    head_size_v=128,
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=128,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout))
 
 
     # Note this will be used in TRT-LLM.
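Note: two things in this hunk can be sanity-checked offline: the layout gating that decides which warp-specialized specs may return softmax statistics, and the shared-memory budget quoted in the docstring. A minimal sketch, assuming the formula and tile sizes shown above (the InputLayout members mirror the hunk; the helper names are illustrative):

from enum import Enum, auto

class InputLayout(Enum):
    CONTIGUOUS_Q_KV = auto()
    Q_PAGED_KV = auto()
    SEPARATE_Q_K_V = auto()

def skip_flags(return_softmax, input_layout):
    # normal attention: softmax stats only with the contiguous Q_KV layout
    skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
    # context MLA: softmax stats only with paged-KV or separate Q/K/V
    skip_mla_combination = return_softmax and input_layout not in (
        InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V)
    return skip_combination, skip_mla_combination

# With softmax stats requested, paged-KV keeps only the MLA spec,
# while contiguous Q_KV keeps only the normal-attention specs.
assert skip_flags(True, InputLayout.Q_PAGED_KV) == (True, False)
assert skip_flags(True, InputLayout.CONTIGUOUS_Q_KV) == (False, True)

def smem_bytes(q_step, d, dv, kv_step, q_buffers=1, kv_buffers=2,
               num_compute_groups=2, ele_size=2):
    # smem = (q_step*d*q_buffers*NUM_COMPUTE_GROUPS
    #         + (kv_step*d + kv_step*dv)*kv_buffers) * ele_size
    return (q_step * d * q_buffers * num_compute_groups
            + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size

# bf16 context MLA spec above: d=192, dv=128, loop_step=64, kv_loop_step=128
assert smem_bytes(64, 192, 128, 128) == 208 * 1024  # the 208 KB quoted in the docstring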

cpp/kernels/fmha_v2/src/fmha/fragment.h

Lines changed: 6 additions & 7 deletions
@@ -1904,8 +1904,7 @@ struct Softmax_saver
         , softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr))
         , softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes)
     {
-        size_t softmax_max_off = sizeof(float) * params.b * params.s * params.h;
-        softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr) + softmax_max_off;
+        softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr);
 
         int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP;
         int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
@@ -1917,9 +1916,9 @@ struct Softmax_saver
         store_softmax_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0);
 
         // assume fixed seq length for the batch
-        size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float);
-        softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes;
+        size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float) * 2;
         softmax_max_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes;
+        softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes + sizeof(float);
     };
 
     inline __device__ void store(int q_loop, float* p_sum, float* p_max)
@@ -1938,19 +1937,19 @@ struct Softmax_saver
                 int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA;
                 if (row0_ + row_offset < actual_q_len_)
                 {
-                    fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0);
                     fmha::stg(softmax_max_ptr_ + row_offset * softmax_stats_stride_in_bytes_, max0);
+                    fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0);
                 }
                 if (row0_ + row_offset + 8 < actual_q_len_)
                 {
-                    fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1);
                     fmha::stg(softmax_max_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, max1);
+                    fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1);
                 }
             }
         }
     }
 
-    // ptr
+    // ptr (total_token_q, h, 2) float
    char* softmax_sum_ptr_ = nullptr;
    char* softmax_max_ptr_ = nullptr;
 
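Note: as I read this change, the softmax statistics move from two separate [b * s * h] float planes (sums first, maxes after them) to a single interleaved float buffer of shape (total_token_q, h, 2), with the row max in slot 0 and the row sum in slot 1 per (token, head), which is what the updated comment describes. A small sketch of the resulting byte offsets, assuming the stride covers h * 2 floats per token (illustrative only):

import numpy as np

FLOAT = np.dtype(np.float32).itemsize  # 4 bytes

def stat_offsets(token_id, head, num_heads):
    # Byte offsets of (max, sum) for one (token, head) in the interleaved layout.
    base = (token_id * num_heads + head) * 2 * FLOAT
    return base, base + FLOAT  # max first, then sum

stats = np.zeros((16, 8, 2), dtype=np.float32)  # (total_token_q, h, 2)
max_off, sum_off = stat_offsets(token_id=3, head=5, num_heads=8)
flat = stats.reshape(-1)
flat[max_off // FLOAT] = 1.25   # row max
flat[sum_off // FLOAT] = 42.0   # row sum
assert stats[3, 5, 0] == 1.25 and stats[3, 5, 1] == 42.0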

cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h

Lines changed: 4 additions & 5 deletions
@@ -465,18 +465,17 @@ struct Softmax_saver_tma
         , softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr))
         , softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes)
     {
-        size_t softmax_max_off = sizeof(float) * params.b * params.s * params.h;
-        softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr) + softmax_max_off;
+        softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr);
         int warp = (threadIdx.x % 128) / Cta_tile::THREADS_PER_WARP;
         int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
         // MMA row0 index (8x4 thread layout)
         row0_ = warp * Mma_tile::M_PER_MMA / WARPS_M + (lane / 4);
 
         int sum_s = params.is_s_padded ? params.s * head_info.bidb : params.cu_q_seqlens[head_info.bidb];
         int token_id = sum_s * params.h + head_info.bidh;
-        size_t const bh_offset = token_id * sizeof(float) + local_q_tile_offset_ * softmax_stats_stride_in_bytes_;
-        softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_;
+        size_t const bh_offset = token_id * sizeof(float) * 2 + local_q_tile_offset_ * softmax_stats_stride_in_bytes_;
         softmax_max_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_;
+        softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_ + sizeof(float);
     };
 
     inline __device__ void store(float* p_sum, float* p_max, float sqrt_d, int row_offset, bool valid_run)
@@ -487,7 +486,7 @@ struct Softmax_saver_tma
         int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
         if (lane % 4 < 2)
        {
-            values = p_sum[lane % 2] == 0.f ? 1.f : 1.0f / p_sum[lane % 2];
+            values = p_sum[lane % 2];
        }
        else
        {
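Note: the TMA saver now writes the raw row sum instead of its reciprocal, presumably because chunked prefill has to merge partial attention results across KV chunks, and that merge needs the true per-row (max, sum) statistics. The sketch below shows the standard merge that such saved stats enable; it is illustrative math only, not the TRT-LLM merge kernel:

import numpy as np

def merge_chunks(o1, m1, s1, o2, m2, s2):
    # Merge two per-chunk attention outputs, each already normalized by its own
    # row sum, using the saved per-row (max, sum) softmax statistics.
    m = np.maximum(m1, m2)
    w1 = s1 * np.exp(m1 - m)
    w2 = s2 * np.exp(m2 - m)
    s = w1 + w2
    o = (w1[:, None] * o1 + w2[:, None] * o2) / s[:, None]
    return o, m, s

def attend(scores, v):
    # Reference single-chunk attention that also returns (max, sum) stats.
    m = scores.max(axis=-1)
    p = np.exp(scores - m[:, None])
    s = p.sum(axis=-1)
    return p @ v / s[:, None], m, s

rng = np.random.default_rng(0)
scores1, scores2 = rng.normal(size=(4, 6)), rng.normal(size=(4, 10))
v1, v2 = rng.normal(size=(6, 8)), rng.normal(size=(10, 8))
merged, _, _ = merge_chunks(*attend(scores1, v1), *attend(scores2, v2))
reference, _, _ = attend(np.concatenate([scores1, scores2], axis=-1),
                         np.concatenate([v1, v2], axis=0))
assert np.allclose(merged, reference)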
