diff --git a/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h b/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h
index f04e74ade..dbe76e5d3 100644
--- a/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h
+++ b/3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h
@@ -5390,8 +5390,8 @@ class FusedMultiHeadAttentionXMMAKernelV2:
 
     virtual uint64_t hashID(const KernelMeta& kernelMeta) const
     {
-        assert(kernelMeta.mD == 64 || kernelMeta.mD == 32 || kernel.mD == 40 || kernel.mD == 80 || kernel.mD == 128
-            || kernel.mD == 160 || kernel.mD == 256);
+        assert(kernelMeta.mD == 64 || kernelMeta.mD == 32 || kernelMeta.mD == 40 || kernelMeta.mD == 80 || kernelMeta.mD == 128
+            || kernelMeta.mD == 160 || kernelMeta.mD == 256);
         return hashID(kernelMeta.mS,
                       kernelMeta.mD,
                       kernelMeta.mInterleaved,
diff --git a/3rdparty/trt_fused_multihead_attention/qkvToContext.cu b/3rdparty/trt_fused_multihead_attention/qkvToContext.cu
index bce5af3e3..54a07c022 100644
--- a/3rdparty/trt_fused_multihead_attention/qkvToContext.cu
+++ b/3rdparty/trt_fused_multihead_attention/qkvToContext.cu
@@ -94,7 +94,9 @@ public:
             warps_n = 8;
         }
         else {
-            assert(false && "Unsupporte seqlen");
+            // S >= 512, flash attention
+            warps_m = 4;
+            warps_n = 1;
         }
     }
     else {
@@ -111,7 +113,9 @@ public:
             warps_n = 8;
         }
         else {
-            assert(false && "Unsupporte seqlen");
+            // S >= 512, flash attention
+            warps_m = 4;
+            warps_n = 1;
         }
     }
     // The number of threads per CTA.