3 changes: 2 additions & 1 deletion cpp/kernels/xqa/mha.cu
@@ -505,7 +505,8 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
bool const maskFlag = col + actualQSeqLen < nbValidCols
? true
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
-acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY;
+acc(m, n)(i, j)
+= maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -mha::numeric_limits<float>::infinity();
}
}
}
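The change above swaps the host-side `INFINITY` macro for `mha::numeric_limits<float>::infinity()` (added in `mha_stdheaders.cuh` below) while keeping the same masking behavior: disallowed score entries are forced to negative infinity so the subsequent softmax assigns them zero weight. The sketch below is a simplified host-side illustration of that intent only; the packed-mask bit indexing and tile layout of the real kernel differ, and all names here are illustrative.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Illustrative sketch (not the kernel's register/tile layout): masked-out score
// entries are overwritten with -infinity so that a later softmax maps them to
// exp(-inf) == 0 and they contribute nothing to the attention weights.
void applyPackedMask(std::vector<float>& scores, uint32_t packedMask, uint32_t nbValidCols)
{
    float const negInf = -std::numeric_limits<float>::infinity();
    for (uint32_t col = 0; col < scores.size(); ++col)
    {
        // keep a column only if it is a valid column AND its mask bit is set
        bool const keep = (col < nbValidCols) && ((packedMask >> col) & 1u);
        if (!keep)
        {
            scores[col] = negInf;
        }
    }
}

int main()
{
    std::vector<float> scores{0.5f, 1.5f, -0.25f, 2.f};
    applyPackedMask(scores, /*packedMask=*/0b0101u, /*nbValidCols=*/3);
    assert(std::isinf(scores[1]) && scores[1] < 0); // bit 1 not set -> masked
    assert(std::isinf(scores[3]) && scores[3] < 0); // col 3 >= nbValidCols -> masked
    return 0;
}
```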
5 changes: 5 additions & 0 deletions cpp/kernels/xqa/mha_stdheaders.cuh
@@ -74,6 +74,11 @@ public:
{
return -3.40282347E+38F;
}
+
+static constexpr float infinity() noexcept
+{
+return __int_as_float(0x7f800000);
+}
};

template <typename T>
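`0x7f800000` is the IEEE-754 binary32 bit pattern for positive infinity (sign 0, exponent all ones, fraction 0), and `__int_as_float` reinterprets those raw bits as a `float` on the device. A minimal host-side check of the same bit pattern, using `memcpy` instead of the CUDA intrinsic:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Sanity check: bit-casting 0x7f800000 yields +infinity, so negating the
// returned value gives the -infinity used for masking in mha.cu.
int main()
{
    uint32_t const bits = 0x7f800000u;
    float value;
    std::memcpy(&value, &bits, sizeof(value)); // bit-for-bit reinterpretation
    assert(std::isinf(value) && value > 0.0f);
    assert(-value == -std::numeric_limits<float>::infinity());
    return 0;
}
```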
@@ -209,13 +209,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key);
TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized());
bool const isSpecDec = xqaParams.multi_query_tokens;
+bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED);
bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA);
-TLLM_CHECK_WITH_INFO(
-!isSpecDec || isGMMAKernel || (isMLAKernel && !xqaParams.spec_decoding_is_generation_length_variable),
-"speculative decoding is available for GMMA/MLA kernel only in JIT path for now. For MLA, the input sequence "
-"length must be uniform and draft tokens must be linear.");
-TLLM_CHECK_DEBUG(isGMMAKernel == jit::supportConfigQGMMA(xqaParams, mSM, false));
+TLLM_CHECK_WITH_INFO(!isSpecDec || !isMLAKernel || !xqaParams.spec_decoding_is_generation_length_variable,
+"speculative decoding for MLA kernel requires that the input sequence length must be uniform and draft tokens "
+"must be linear.");
+TLLM_CHECK_DEBUG(!isHMMAKernel || jit::supportConfigHMMA(xqaParams, mSM, false));
+TLLM_CHECK_DEBUG(!isGMMAKernel || jit::supportConfigQGMMA(xqaParams, mSM, false));
+TLLM_CHECK_DEBUG(!isMLAKernel || jit::supportConfigMLA(xqaParams, mSM, false));
// @fixme: also embed these compile-time flags in cubin directly
// Whether RoPE is fused into the XQA kernel.
// * If applyRoPEInXqaKernel is true, XQA kernel applies RoPE AND performs SDPA.
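The reorganized checks above follow one pattern per kernel type: each `TLLM_CHECK_DEBUG(!isX || supportConfigX(...))` is the implication "if the selected cubin is of type X, its config must be supported", and the new MLA guard reads the same way. A small sketch of that disjunctive-implication form (the helper name is illustrative):

```cpp
#include <cassert>

// !isSpecDec || !isMLAKernel || !isVarLen is the implication
// (isSpecDec && isMLAKernel) => !isVarLen written in disjunctive form:
// spec-dec on the MLA kernel is only allowed with uniform generation lengths.
constexpr bool specDecMlaGuard(bool isSpecDec, bool isMLAKernel, bool isVarLen)
{
    return !isSpecDec || !isMLAKernel || !isVarLen;
}

int main()
{
    static_assert(specDecMlaGuard(true, true, false), "uniform lengths: allowed");
    static_assert(!specDecMlaGuard(true, true, true), "variable lengths on MLA: rejected");
    static_assert(specDecMlaGuard(true, false, true), "non-MLA kernels are not constrained here");
    return 0;
}
```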
@@ -355,7 +357,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
.mask = reinterpret_cast<SpecDecParams::MaskType const*>(xqaParams.spec_decoding_packed_mask)};
};

-constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 15;
+constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 20;
uint32_t idxNextParam = 0;
void* kernelParams[kMAX_NB_KERNEL_PARAMS];
auto appendParam = [&](auto* p) mutable
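`kernelParams` follows the usual driver-API convention: an array holding the address of each kernel argument, in the exact order the kernel expects; the new HMMA spec-dec branch below also terminates it with a trailing `nullptr` guard. Raising `kMAX_NB_KERNEL_PARAMS` from 15 to 20 presumably makes room for the extra arguments that branch appends. A minimal sketch of the pattern with hypothetical names (`FakeLauncher`, `kMaxParams`):

```cpp
#include <cstdint>

// Stand-in for a cuLaunchKernel-style entry point that receives the argument
// pointer array; not the actual TensorRT-LLM launcher API.
struct FakeLauncher
{
    void launch(void* const* /*kernelParams*/) {}
};

int main()
{
    constexpr uint32_t kMaxParams = 20;      // upper bound, not an exact count
    void* kernelParams[kMaxParams] = {};
    uint32_t idx = 0;
    auto appendParam = [&](auto* p) { kernelParams[idx++] = p; };

    uint32_t qSeqLen = 8;
    uint32_t headGrpSize = 4;
    appendParam(&qSeqLen);                   // each entry stores the argument's address
    appendParam(&headGrpSize);
    kernelParams[idx] = nullptr;             // trailing nullptr guard, as in the real code

    FakeLauncher{}.launch(kernelParams);
    return 0;
}
```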
@@ -391,6 +393,59 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
dim3 const blockDim(128 * 3, 1, 1);
cubinObj->launch(dimGrid, blockDim, stream, kernelParams);
}
+else if (isHMMAKernel && isSpecDec)
+{
+// MultiQueryTokens (generation_input_length > 1) need extra parameters (like qSeqLen, headGrpSize, and
+// mask). Input parameters for MultiQueryTokens kernels.
+unsigned int headGrpSize = num_q_heads_over_kv;
+// Use mTileSize = 16 kernels when qSeqLen <= 16.
+unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
+unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
+unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize);
+int const* maskPtr = xqaParams.spec_decoding_packed_mask;
+int const* cuQSeqLens = launchParams.cu_seq_lens;
+unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable ? // true for ReDrafter
+xqaParams.spec_decoding_max_generation_length
+: qSeqLen;
+
+appendParam(&maxQSeqLen);
+appendParam(&launchParams.num_k_heads);
+appendParam(&headGrpSize);
+appendParam(&cuQSeqLens);
+bool const allowSlidingWindow
+= !(isSpecDec && xqaParams.is_spec_dec_tree); // sliding window does not support spec dec with tree-based
+// tokens, only chained tokens
+if (allowSlidingWindow)
+{
+appendParam(&launchParams.slidingWindowSize);
+}
+appendParam(&launchParams.qScale);
+appendParam(&launchParams.output);
+if (isFp8Out && !needOutputCvt)
+{
+appendParam(&launchParams.rcpOutScale);
+}
+appendParam(&kernel_input_tokens);
+TLLM_CHECK_DEBUG(!applyRoPEInXqaKernel);
+appendParam(&maskPtr);
+appendParam(&xqaParams.attention_sinks);
+appendParam(&launchParams.kvCacheParams);
+if (xqaParams.beam_width > 1)
+{
+appendParam(&launchParams.beamSearchParams.value());
+}
+appendParam(&launchParams.batch_size);
+appendParam(&launchParams.kv_scale_quant_orig);
+appendParam(&launchParams.semaphores);
+appendParam(&launchParams.scratch);
+kernelParams[idxNextParam] = nullptr; // one extra nullptr at end as guard.
+// TODO: merge SingleQueryToken params and MultiQueryTokens params into one kernelParams.
+// XQA HMMA Spec-dec kernel (previously precompiled) does not support multi-block mode
+uint32_t multi_block = 1;
+auto const gridDim = dim3{multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp, xqaParams.batch_size};
+dim3 const blockDim(128, 1, 2);
+cubinObj->launch(gridDim, blockDim, stream, kernelParams);
+}
else
{
appendParam(&launchParams.num_k_heads);
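In the HMMA spec-dec launch above, the attention problem has `qSeqLen * headGrpSize` query rows per KV head; these are split into M-tiles of `mTileSize` rows, and the grid's Y dimension packs `num_kv_heads * nbTokenBlocksPerGrp` such tiles. A worked example with assumed values (qSeqLen = 10, headGrpSize = 8, 4 KV heads, batch 2):

```cpp
#include <cassert>
#include <cstdint>

// divUp is assumed to be the usual ceiling-division helper used in the diff.
constexpr uint32_t divUp(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    uint32_t const qSeqLen = 10;    // generation_input_length
    uint32_t const headGrpSize = 8; // num_q_heads / num_kv_heads
    uint32_t const numKvHeads = 4;
    uint32_t const batchSize = 2;

    // mTileSize = 16 kernels are used when qSeqLen <= 16, otherwise 32.
    uint32_t const mTileSize = qSeqLen <= 16 ? 16u : 32u;
    uint32_t const nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize);
    assert(mTileSize == 16 && nbTokenBlocksPerGrp == 5); // divUp(80, 16) == 5

    // grid = {multi_block, num_kv_heads * nbTokenBlocksPerGrp, batch_size}
    uint32_t const gridY = numKvHeads * nbTokenBlocksPerGrp;
    uint32_t const gridZ = batchSize;
    assert(gridY == 20 && gridZ == 2);
    return 0;
}
```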
@@ -430,8 +485,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
SpecDecParams specDecParams{};
if (isSpecDec)
{
-TLLM_CHECK_WITH_INFO(
-isGMMAKernel, "speculative decoding is available for GMMA kernel only in JIT path for now.");
+TLLM_CHECK_DEBUG(isGMMAKernel);
TLLM_CHECK_DEBUG_WITH_INFO(xqaParams.max_past_kv_length + 1 <= xqaParams.cyclic_attention_window_size,
"SWA and speculative decoding cannot be used at the same time for now.");
specDecParams = makeSpecDecParams();
@@ -74,20 +74,6 @@ constexpr inline T roundUp(T a, T b)

DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams, bool for_configure_plugin)
{
-int const smVersion = tensorrt_llm::common::getSMVersion();
-if (xqaParams.multi_query_tokens)
-{
-auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
-// Ampere XQA supports spec dec with pre-compiled cubins (may also work with JIT but not implemented yet)
-// Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
-// now.
-bool const supportedByHopperXqa
-= (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
-bool const supportedBySm120Mla
-= (smVersion == 120 && xqaParams.isMLA() && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
-return (supportedByHopperXqa || supportedBySm120Mla) ? mJITImpl.get() : mPrecompiledImpl.get();
-}
-
std::optional<bool> envEnableXQAJIT = tensorrt_llm::common::getEnvEnableXQAJIT();

if (envEnableXQAJIT.has_value())
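With the `multi_query_tokens` special case removed above, spec-dec requests go through the same selection path as everything else: `getEnvEnableXQAJIT()` returns a `std::optional<bool>`, and the JIT/precompiled choice is overridden only when the environment variable is set. A rough sketch of that optional-override pattern; the helper, the variable name, and the parsing rule below are assumptions for illustration, not the actual TensorRT-LLM API:

```cpp
#include <cstdlib>
#include <optional>
#include <string>

// Hypothetical helper: an unset variable yields std::nullopt so the default
// implementation choice is kept; any value other than "0" forces JIT on.
std::optional<bool> getEnvFlag(char const* name)
{
    char const* v = std::getenv(name);
    if (v == nullptr)
    {
        return std::nullopt;
    }
    return std::string(v) != "0";
}

int main()
{
    bool useJit = true; // default choice when no override is present
    if (auto flag = getEnvFlag("TRTLLM_ENABLE_XQA_JIT"); flag.has_value())
    {
        useJit = *flag;
    }
    return useJit ? 0 : 1;
}
```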