diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 3cca3f9b8d..51c82e7400 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -444,9 +444,9 @@ def _flash_prefill_fwd_kernel_with_mask( stride_os: tl.constexpr, stride_oh: tl.constexpr, stride_od: tl.constexpr, - stride_amb: tl.constexpr, - stride_amqs: tl.constexpr, - stride_amkvs: tl.constexpr, + stride_amb, + stride_amqs, + stride_amkvs, kv_group_num, head_dim_k, head_dim_v,