Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cpp/tensorrt_llm/common/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;

// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;

return true;
}

Expand Down Expand Up @@ -2229,6 +2232,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
{
TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
}
else
{
TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase.");
}
}

// This is the number of kv tokens that q needs to visit, but excluding one as it will be processed before the kv
Expand Down Expand Up @@ -2750,7 +2757,7 @@ int AttentionOp::initialize() noexcept
!useCustomMask() || mEnableContextFMHA, "Only Context FMHA supports custom mask input currently.");
}

mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled) && !mCrossAttention
mEnableXQA = (mEnableXQA || mIsSpecDecodingEnabled)
&& (mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16) && mUseKVCache;

if (mEnableXQA)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ struct XQAParams

void* quant_q_buffer_ptr = nullptr;

// for cross attention
int32_t const* encoder_input_lengths = nullptr;

cudaStream_t stream = 0;

std::string toString() const
Expand Down Expand Up @@ -175,6 +178,7 @@ struct XQAParams
<< "total_num_input_tokens :" << total_num_input_tokens << std ::endl
<< "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
<< "fp8_out_scale :" << fp8_out_scale << std ::endl
<< "encoder_input_lengths: " << encoder_input_lengths << std::endl
<< "stream :" << stream;

return ss.str();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1348,15 +1348,17 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
int const batch_idx = blockIdx.z;

// The decoder sequence length.
int const decoder_seq_len = params.seq_lens[batch_idx];
// Spec decoding not supported for cross-attention at the moment so we can set 1 and batch_idx here
int const decoder_seq_len = params.generation_phase ? 1 : params.seq_lens[batch_idx];
// The decoder sequence offset.
int const decoder_seq_offset = params.cu_seq_lens[batch_idx];
int const decoder_seq_offset = params.generation_phase ? batch_idx : params.cu_seq_lens[batch_idx];
// The decoder cache sequence length (includes the current input).
int const decoder_cache_seq_len = params.cache_seq_lens[batch_idx];
// The encoder sequence length.
int const encoder_seq_len = params.encoder_seq_lens[batch_idx];
// The encoder sequence offset.
int const encoder_seq_offset = params.cu_kv_seq_lens[batch_idx];
// Not needed in Gen phase
int const encoder_seq_offset = params.generation_phase ? -1 : params.cu_kv_seq_lens[batch_idx];
// THe maximum sequence length of encoder and decoder.
int const max_seq_len = max(decoder_seq_len, encoder_seq_len);

Expand Down Expand Up @@ -1411,45 +1413,49 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
}
}

// Encoder tokens (i.e. KV tokens).
if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
&& store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
if (!params.generation_phase)
{
// The global token idx in all sequences.
int global_token_idx = token_idx + encoder_seq_offset;

// The memory offset.
auto const src_k_idx = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
auto const src_v_idx
= static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;

// Only load K,V tokens from encoder qkv input.
auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);

// The kv cache pointers.
auto k_cache_block_ptr
= reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
auto v_cache_block_ptr
= reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
// The vector idx in the cache block.
auto block_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);

// Store K and V to the cache.
// INT8/FP8 kv cache.
if constexpr (sizeof(TCache) == 1)
{
// The element index inside the block.
auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
// Store 8bits kv cache.
mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
}
else
// Encoder tokens (i.e. KV tokens).
if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
&& store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
{
reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
// The global token idx in all sequences.
int global_token_idx = token_idx + encoder_seq_offset;

// The memory offset.
auto const src_k_idx
= static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
auto const src_v_idx
= static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;

// Only load K,V tokens from encoder qkv input.
auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);

// The kv cache pointers.
auto k_cache_block_ptr
= reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
auto v_cache_block_ptr
= reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
// The vector idx in the cache block.
auto block_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);

// Store K and V to the cache.
// INT8/FP8 kv cache.
if constexpr (sizeof(TCache) == 1)
{
// The element index inside the block.
auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
// Store 8bits kv cache.
mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
}
else
{
reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
}
}
}
}
Expand Down
160 changes: 95 additions & 65 deletions cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include "xqaDispatcher.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include <cstdint>

namespace
{
Expand All @@ -38,6 +40,87 @@ constexpr inline T roundUp(T a, T b)
namespace tensorrt_llm::kernels
{

namespace
{

template <typename T, typename KVCacheBuffer>
QKVPreprocessingParams<T, KVCacheBuffer> makeQKVPreprocessingParams(XQAParams const& params,
    XQALaunchParam<KVCacheBuffer> const& launchParams, void* xqa_q_input_ptr, Data_type QDataType,
    KvCacheDataType cache_type, int32_t batch_beam_size, KVCacheBuffer const& kv_cache_buffer,
    int32_t const* cu_seqlens, int32_t const* cu_kv_seqlens, float const* rotary_inv_freq_buf, int multiProcessorCount)
{
    // Assemble the generation-phase parameters for the QKV preprocessing kernel
    // (RoPE application + KV-cache append) from the XQA dispatch inputs.
    // Zero the whole struct first so every member not explicitly assigned below
    // is guaranteed null/zero.
    QKVPreprocessingParams<T, KVCacheBuffer> prep;
    memset(&prep, 0, sizeof(prep));

    // Input/output buffers.
    prep.qkv_input = static_cast<T*>(const_cast<void*>(params.qkv));
    prep.q_output = static_cast<T*>(xqa_q_input_ptr);
    prep.kv_cache_buffer = kv_cache_buffer;
    prep.kv_cache_block_scales_buffer = {};
    prep.qkv_bias = static_cast<T const*>(params.qkv_bias);

    // Scale pointers consumed by the FMHA kernel.
    prep.fmha_bmm1_scale = launchParams.bmm1_scale_ptr;
    prep.fmha_bmm2_scale = launchParams.bmm2_scale_ptr;
    bool const quantizeQToFp8 = (QDataType == DATA_TYPE_E4M3);
    if (params.kv_cache_quant_mode.hasFp8KvCache())
    {
        prep.q_scale_quant_orig = params.kv_scale_quant_orig;
        prep.kv_scale_quant_orig = params.kv_scale_quant_orig;
    }
    if (params.is_fp8_output)
    {
        prep.o_scale_orig_quant = params.fp8_out_scale;
    }

    // Sequence-length and rotary-embedding buffers.
    prep.logn_scaling = params.logn_scaling_ptr;
    prep.seq_lens = params.spec_decoding_generation_lengths;
    prep.cache_seq_lens = params.sequence_lengths;
    prep.cu_seq_lens = cu_seqlens;
    prep.rotary_embedding_inv_freq = rotary_inv_freq_buf;
    prep.rotary_coef_cache_buffer = params.rotary_cos_sin;
    prep.kvScaleOrigQuant = params.kv_scale_orig_quant;
    prep.kv_cache_scale_factors = nullptr;
    // Spec-decoding position offsets are dropped for cross attention.
    prep.spec_decoding_position_offsets = params.cross_attention ? nullptr : params.spec_decoding_position_offsets;
    prep.mrope_position_deltas = params.mrope_position_deltas;

    // Scalar configuration.
    prep.batch_size = int(batch_beam_size);
    prep.max_input_seq_len = params.generation_input_length;
    prep.max_kv_seq_len = params.max_past_kv_length;
    // Cross attention keeps the full encoder context: no cyclic window, no sink tokens.
    prep.cyclic_kv_cache_len = params.cross_attention ? params.max_past_kv_length : params.cyclic_attention_window_size;
    prep.sink_token_len = params.cross_attention ? 0 : params.sink_token_length;
    prep.token_num = params.total_num_input_tokens;
    prep.remove_padding = true;
    prep.cross_attention = params.cross_attention;
    prep.head_num = params.num_q_heads;
    prep.kv_head_num = params.num_kv_heads;
    prep.qheads_per_kv_head = params.num_q_heads / params.num_kv_heads;
    prep.size_per_head = params.head_size;
    prep.fmha_host_bmm1_scale = 1.0f / (sqrtf(params.head_size * 1.0f) * params.q_scaling);
    prep.rotary_embedding_dim = params.rotary_embedding_dim;
    prep.rotary_embedding_base = params.rotary_embedding_base;
    prep.rotary_scale_type = params.rotary_embedding_scale_type;
    prep.rotary_embedding_scale = params.rotary_embedding_scale;
    prep.rotary_embedding_max_positions = params.rotary_embedding_max_positions;
    prep.position_embedding_type = params.position_embedding_type;
    prep.position_shift_enabled = params.position_shift_enabled;
    prep.cache_type = cache_type;
    prep.separate_q_kv_output = true;
    prep.quantized_fp8_output = quantizeQToFp8;
    prep.generation_phase = true;
    prep.multi_processor_count = multiProcessorCount;
    prep.rotary_vision_start = params.rotary_vision_start;
    prep.rotary_vision_length = params.rotary_vision_length;

    // Cross-attention only: per-batch lengths of the encoder sequences providing K/V.
    prep.encoder_seq_lens = params.encoder_input_lengths;

    // NOTE(review): cu_kv_seqlens is accepted but not assigned to any field here —
    // the generation-phase kernel path appears not to read cumulative KV offsets;
    // confirm before removing the parameter.
    (void) cu_kv_seqlens;

    return prep;
}

} // namespace

////////////////////////////////////////////////////////////////////////////////////////////////////

XqaDispatcher::XqaDispatcher(XqaFixedParams fixedParams)
Expand Down Expand Up @@ -137,9 +220,10 @@ bool XqaDispatcher::shouldUse(XQAParams const& params)
{
SHOULD_NOT_USE("Fallback to MMHA as unidirectional is not supported by TRTLLM-GEN kernels.");
}
if (params.cross_attention)
if (params.cross_attention && !params.paged_kv_cache)
{
SHOULD_NOT_USE("Fallback to MMHA as cross attention is not supported by TRTLLM-GEN kernels.");
SHOULD_NOT_USE(
"Fallback to MMHA as cross attention without paged KV Cache is not supported by TRTLLM-GEN kernels.");
}
if (params.paged_kv_cache && params.tokens_per_block < 8)
{
Expand Down Expand Up @@ -252,8 +336,8 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
decoder_params.seqQOffsets = launchParams.cu_seq_lens;
decoder_params.seqKVOffsets = launchParams.cu_kv_seq_lens;
decoder_params.seqQLengths = params.spec_decoding_generation_lengths;
decoder_params.seqKVLengths = params.sequence_lengths;
decoder_params.batchSize = int(batch_beam_size);
decoder_params.seqKVLengths = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
decoder_params.batchSize = static_cast<int>(batch_beam_size);
decoder_params.maxQSeqLength = params.generation_input_length;
decoder_params.numTokens = params.total_num_input_tokens;
decoder_params.removePadding = true;
Expand All @@ -273,10 +357,12 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
float const* rotary_inv_freq_buf = params.rotary_embedding_inv_freq_cache;
// Use the nullptr for cu_seqlens when it is not computed.
int const* cu_seqlens{nullptr};
int const* cu_kv_seqlens{nullptr};
if (decoder_params.isBuildDecoderInfoKernelNeeded())
{
rotary_inv_freq_buf = launchParams.rotary_inv_freq_buf;
cu_seqlens = launchParams.cu_seq_lens;
cu_kv_seqlens = launchParams.cu_kv_seq_lens;
invokeBuildDecoderInfo(decoder_params, params.stream);
sync_check_cuda_error(params.stream);
}
Expand All @@ -285,66 +371,10 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
// NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
void* xqa_q_input_ptr = inputScratch;
// The preprocessing kernel that applies RoPE and updates kv cache.
QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms;
memset(&preprocessingParms, 0, sizeof(preprocessingParms));
// Set parameters.
preprocessingParms.qkv_input = static_cast<T*>(const_cast<void*>(params.qkv));
preprocessingParms.q_output = static_cast<T*>(xqa_q_input_ptr);
preprocessingParms.kv_cache_buffer = kv_cache_buffer;
preprocessingParms.kv_cache_block_scales_buffer = {};
preprocessingParms.qkv_bias = static_cast<T const*>(params.qkv_bias);
// Prepare values for fmha.
preprocessingParms.fmha_bmm1_scale = launchParams.bmm1_scale_ptr;
preprocessingParms.fmha_bmm2_scale = launchParams.bmm2_scale_ptr;
bool const is_fp8_q_input = (mQDataType == DATA_TYPE_E4M3);
if (params.kv_cache_quant_mode.hasFp8KvCache())
{
preprocessingParms.q_scale_quant_orig = params.kv_scale_quant_orig;
preprocessingParms.kv_scale_quant_orig = params.kv_scale_quant_orig;
}
if (params.is_fp8_output)
{
preprocessingParms.o_scale_orig_quant = params.fp8_out_scale;
}
// Buffers.
preprocessingParms.logn_scaling = params.logn_scaling_ptr;
preprocessingParms.seq_lens = params.spec_decoding_generation_lengths;
preprocessingParms.cache_seq_lens = params.sequence_lengths;
preprocessingParms.cu_seq_lens = cu_seqlens;
preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin;
preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant;
preprocessingParms.kv_cache_scale_factors = nullptr;
preprocessingParms.spec_decoding_position_offsets = params.spec_decoding_position_offsets;
preprocessingParms.mrope_position_deltas = params.mrope_position_deltas;
// Scalar parameters.
preprocessingParms.batch_size = int(batch_beam_size);
preprocessingParms.max_input_seq_len = params.generation_input_length;
preprocessingParms.max_kv_seq_len = params.max_past_kv_length;
preprocessingParms.cyclic_kv_cache_len = params.cyclic_attention_window_size;
preprocessingParms.sink_token_len = params.sink_token_length;
preprocessingParms.token_num = params.total_num_input_tokens;
preprocessingParms.remove_padding = true;
preprocessingParms.cross_attention = false;
preprocessingParms.head_num = params.num_q_heads;
preprocessingParms.kv_head_num = params.num_kv_heads;
preprocessingParms.qheads_per_kv_head = params.num_q_heads / params.num_kv_heads;
preprocessingParms.size_per_head = params.head_size;
preprocessingParms.fmha_host_bmm1_scale = 1.0f / (sqrtf(params.head_size * 1.0f) * params.q_scaling);
preprocessingParms.rotary_embedding_dim = params.rotary_embedding_dim;
preprocessingParms.rotary_embedding_base = params.rotary_embedding_base;
preprocessingParms.rotary_scale_type = params.rotary_embedding_scale_type;
preprocessingParms.rotary_embedding_scale = params.rotary_embedding_scale;
preprocessingParms.rotary_embedding_max_positions = params.rotary_embedding_max_positions;
preprocessingParms.position_embedding_type = params.position_embedding_type;
preprocessingParms.position_shift_enabled = params.position_shift_enabled;
preprocessingParms.cache_type = cache_type;
preprocessingParms.separate_q_kv_output = true;
preprocessingParms.quantized_fp8_output = is_fp8_q_input;
preprocessingParms.generation_phase = true;
preprocessingParms.multi_processor_count = mMultiProcessorCount;
preprocessingParms.rotary_vision_start = params.rotary_vision_start;
preprocessingParms.rotary_vision_length = params.rotary_vision_length;

auto preprocessingParms = makeQKVPreprocessingParams<T, KVCacheBuffer>(params, launchParams, xqa_q_input_ptr,
mQDataType, cache_type, batch_beam_size, kv_cache_buffer, cu_seqlens, cu_kv_seqlens, rotary_inv_freq_buf,
mMultiProcessorCount);

invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, params.stream);
sync_check_cuda_error(params.stream);
Expand Down Expand Up @@ -394,7 +424,7 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff
= reinterpret_cast<float const*>(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr);
tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale;
// The sequence lengths for K/V.
tllmRunnerParams.seqLensKvPtr = params.sequence_lengths;
tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;

tllmRunnerParams.oPtr = params.output;
tllmRunnerParams.oSfPtr = params.output_sf;
Expand Down