
Commit 581cc30

brb-nv authored and symphonylyh committed
AGI 0804 cherry-pick NVIDIA#6499, NVIDIA#6526, NVIDIA#4416
fix: Fix poor generation with FP8 Gemma3 1B checkpoint (NVIDIA#6499)
Signed-off-by: Balaram Buddharaju <[email protected]>

[None][fix] Serialize the window_size in the kv event (NVIDIA#6526)
Signed-off-by: richardhuo-nv <[email protected]>

[None][feat] Multi-block mode for Hopper spec dec XQA kernel (NVIDIA#4416)
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent 6ecaeee commit 581cc30

File tree

14 files changed, +161 -50 lines

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 19 additions & 24 deletions
@@ -2074,35 +2074,31 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     debugCheckSemaphores(stream);
 #endif
 
-    // Medusa doesn't support multi-block mode.
-    if (!(mIsSpecDecodingEnabled && mUseSpecDecoding))
+    if (params.runtime_perf_knobs)
     {
-        if (params.runtime_perf_knobs)
-        {
-            int64_t multi_block_mode_val = params.runtime_perf_knobs[0];
-            mMultiBlockMode = multi_block_mode_val == 1;
-            if (common::getEnvForceDeterministicAttention())
-            {
-                mMultiBlockMode = false;
-            }
-        }
-
+        int64_t multi_block_mode_val = params.runtime_perf_knobs[0];
+        mMultiBlockMode = multi_block_mode_val == 1;
         if (common::getEnvForceDeterministicAttention())
         {
             mMultiBlockMode = false;
         }
+    }
 
-        // TODO only for debug usage
-        if (!mMultiBlockMode)
-        {
-            char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE");
-            bool isForceMultiBlockMode
-                = (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON");
-            TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode),
-                "FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at "
-                "the same time.");
-            mMultiBlockMode = isForceMultiBlockMode;
-        }
+    if (common::getEnvForceDeterministicAttention())
+    {
+        mMultiBlockMode = false;
+    }
+
+    // TODO only for debug usage
+    if (!mMultiBlockMode)
+    {
+        char* isForceMultiBlockModeChar = std::getenv("FORCE_MULTI_BLOCK_MODE");
+        bool isForceMultiBlockMode
+            = (isForceMultiBlockModeChar != nullptr && std::string(isForceMultiBlockModeChar) == "ON");
+        TLLM_CHECK_WITH_INFO(!(common::getEnvForceDeterministicAttention() && isForceMultiBlockMode),
+            "FORCE_MULTI_BLOCK_MODE and FORCE_DETERMINISTIC/FORCE_ATTENTION_KERNEL_DETERMINISTIC can not be set at "
+            "the same time.");
+        mMultiBlockMode = isForceMultiBlockMode;
     }
 
     // Check that the chunked-attention and sliding-window-attention are not enabled at the same time.

@@ -2720,7 +2716,6 @@ int AttentionOp::initialize() noexcept
     {
         fixedParams.outputDataType = DATA_TYPE_E4M3;
         TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
-        TLLM_CHECK_WITH_INFO(!mMultiBlockMode, "Medusa doesn't support multi-block mode.");
     }
     fixedParams.numQHeads = mNumAttnHeads;
     fixedParams.numKvHeads = mNumAttnKVHeads;

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 85 additions & 0 deletions
@@ -378,5 +378,90 @@ inline int computeMultiBlockCountForMLA(XQAParams const& xqaParams, int multipro
     return 1; // disable multi-block for MLA kernel for now.
 }
 
+inline int computeMultiBlockCountSpecDecGMMA(
+    XQAParams const& xqaParams, int batch_size, int multiprocessor_count, int specDecBlocks)
+{
+    auto const userSpecified = tensorrt_llm::common::getEnvXqaBlocksPerSequence();
+    if (userSpecified.has_value())
+    {
+        return userSpecified.value();
+    }
+    int multi_block_count = 1;
+
+    int num_kv_heads = xqaParams.num_kv_heads;
+    int history_length = xqaParams.max_past_kv_length;
+
+    // skip tuning for large BS or short ISL case.
+    if (batch_size > 32 || history_length < 2048)
+    {
+        return multi_block_count;
+    }
+
+    // gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
+    int single_block_count = specDecBlocks * num_kv_heads * batch_size;
+    double wave_count = (double) single_block_count / (double) multiprocessor_count;
+
+    // Multi block tuning for low CTA: populating CTAs to at most 1 wave of SMs
+    if (wave_count < 1)
+    {
+        auto highestPowerof2 = [](int x)
+        {
+            x |= x >> 1;
+            x |= x >> 2;
+            x |= x >> 4;
+            x |= x >> 8;
+            x |= x >> 16;
+            return x ^ (x >> 1);
+        };
+
+        // calculate the maximum blocks to be populated at most 1 wave
+        multi_block_count = floor(multiprocessor_count / single_block_count);
+        // make multi_block_count a power of 2 for tuning convenience.
+        multi_block_count = highestPowerof2(multi_block_count);
+        // make multi_block_count at most 64 and at least 1.
+        multi_block_count = std::min(multi_block_count, 64);
+        multi_block_count = std::max(multi_block_count, 1);
+
+        // tune only when original CTA is too small, multi_block_count is too big, and history length < 2^16
+        // For Hopper, most cases there are 114, 132, 144 SMs. For H20 about 78.
+        // single_block_count = [1..8]
+        // multi_block_count = [16,32,64,128]
+        // history_length = [1024..65536]
+        if (single_block_count <= 8 && multi_block_count >= 16 && history_length < 65536)
+        {
+            if (history_length < 2048)
+            {
+                // for history length < 2048 and low CTA, scaling is not effective, so we set a hard limit to
+                // multi_block_count = 4
+                multi_block_count = std::min(multi_block_count, 4);
+            }
+            else if (history_length < 65536)
+            {
+                // at single_block == 8, multi_block_count can only be 16. (SM / 8 ~= 16)
+                // tune only 2048 <= kvlen < 8192
+                if (single_block_count == 8 && history_length <= 8192)
+                {
+                    multi_block_count >>= 1;
+                }
+                else
+                {
+                    auto getLog2 = [](int x) { return x ? 31 - __builtin_clz(x) : -1; };
+                    auto history_length_log2 = getLog2(history_length);
+                    // Adjust multi_block_count based on history length using formula:
+                    // shift_amount = 3 - (log2(history_length) - 10) / 2
+                    // This gives us:
+                    // - history_length in [2^11, 2^12): shift by 3
+                    // - history_length in [2^13, 2^14): shift by 2
+                    // - history_length in [2^15, 2^16): shift by 1
+                    multi_block_count >>= 3 - (history_length_log2 - 10) / 2;
+                }
+            }
+        }
+        TLLM_CHECK_WITH_INFO((multi_block_count * single_block_count) <= multiprocessor_count,
+            "The adjusted MultiBlock exceed number of SMs, adding additional wave may result to perf drop.");
+    }
+    return multi_block_count;
+}
+
 } // namespace kernels
 } // namespace tensorrt_llm
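
Note (not part of the commit): a small standalone Python sketch of the tuning heuristic in computeMultiBlockCountSpecDecGMMA above, so the wave-count and shift arithmetic can be checked outside the kernel. Names and the 132-SM figure are illustrative; the real kernel reads these values from XQAParams and the device, and the unreachable history_length < 2048 branch is omitted here.

# Mirrors the heuristic added above: fill at most one wave of SMs with CTAs,
# round down to a power of two, clamp to [1, 64], then shrink for long histories.
def highest_power_of_2(x: int) -> int:
    return 1 << (x.bit_length() - 1)  # largest power of two <= x, for x > 0

def spec_dec_gmma_multi_block(batch_size: int, num_kv_heads: int, spec_dec_blocks: int,
                              history_length: int, sm_count: int) -> int:
    multi_block_count = 1
    # skip tuning for large batch or short history, as in the kernel code
    if batch_size > 32 or history_length < 2048:
        return multi_block_count
    single_block_count = spec_dec_blocks * num_kv_heads * batch_size
    if single_block_count / sm_count < 1:  # less than one full wave of CTAs
        multi_block_count = highest_power_of_2(max(1, sm_count // single_block_count))
        multi_block_count = max(1, min(multi_block_count, 64))
        if single_block_count <= 8 and multi_block_count >= 16 and history_length < 65536:
            if single_block_count == 8 and history_length <= 8192:
                multi_block_count >>= 1
            else:
                # shift_amount = 3 - (log2(history_length) - 10) / 2
                log2_history = history_length.bit_length() - 1
                multi_block_count >>= 3 - (log2_history - 10) // 2
    return multi_block_count

# Example: batch 1, 8 KV heads, 4 spec-dec blocks, 16K history, 132 SMs:
# single_block_count = 32, one wave allows 132 // 32 = 4 blocks -> returns 4.
print(spec_dec_gmma_multi_block(1, 8, 4, 16384, 132))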

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 9 additions & 1 deletion
@@ -449,7 +449,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
     uint32_t multi_block = 1;
     if (xqaParams.multi_block_mode)
     {
-        multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
+        if (isSpecDec && isGMMAKernel)
+        {
+            multi_block = computeMultiBlockCountSpecDecGMMA(
+                xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks);
+        }
+        else if (!isSpecDec)
+        {
+            multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
+        }
     }
     uint32_t const nbKVHeads = xqaParams.num_kv_heads;
     auto const gridDim = (isGMMAKernel ? dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 1 addition & 7 deletions
@@ -285,14 +285,8 @@ class XQAKernelList
         void* kernelParams[] = {&maxQSeqLen, &launchParams.num_k_heads, &headGrpSize, &cuQSeqLens,
             &launchParams.output, &xqa_q_input_ptr, &maskPtr, &launchParams.kvCacheParams, &launchParams.batch_size,
             &launchParams.kv_scale_quant_orig, &launchParams.scratch};
+        // precompiled XQA Spec-dec kernel does not support multi-block mode
         int multi_block = 1;
-        if (xqaParams.multi_block_mode)
-        {
-            multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
-            check_cuda_error(cudaMemsetAsync(xqaParams.workspaces, 0,
-                sizeof(int) * xqaParams.batch_size * qSeqLen * xqaParams.num_kv_heads, stream));
-            sync_check_cuda_error(stream);
-        }
         TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
             xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr));
     }

cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp

Lines changed: 1 addition & 2 deletions
@@ -72,8 +72,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
     mMaskType = mask_type;
     mBlockSparseParams = block_sparse_params;
     mType = type;
-    mMultiBlockMode
-        = is_spec_decoding_enabled ? false : true; // set to true in build time to account for enough workspace size
+    mMultiBlockMode = true;
     mEnableXQA = true;
     mKVCacheQuantMode = tc::QuantMode(kv_cache_quant_mode);
    mRemovePadding = remove_input_padding;

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp

Lines changed: 0 additions & 2 deletions
@@ -703,8 +703,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
             = static_cast<bool>(reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_USE)])[0]);
         changeSpecDecodingMode = mUseSpecDecoding != useSpecDecoding;
         mUseSpecDecoding = useSpecDecoding;
-        // change mMultiBlockMode to default
-        mMultiBlockMode = mUseSpecDecoding ? false : true;
     }
 
     [[maybe_unused]] MlaParams<T> mla_params;

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 0 additions & 1 deletion
@@ -527,7 +527,6 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
     op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled
     op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
     op->mIsSpecDecTree = spec_decoding_bool_params[2]; // is_spec_dec_tree
-    op->mMultiBlockMode = op->mIsSpecDecodingEnabled ? false : true;
 
     if (is_mla_enable)
     {

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 14 additions & 10 deletions
@@ -156,25 +156,29 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
 
 class Gemma3MLP(nn.Module):
 
-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
+        self.quant_config = model_config.get_quant_config()
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=self.quant_config)
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
+        self.act_fn = ACT2FN[self.config.hidden_activation]
 
     def forward(self, x):
         down_proj = self.down_proj(

@@ -199,7 +203,7 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
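
Note (not part of the commit): the Gemma3MLP change above is the substance of the FP8 Gemma3 1B fix (NVIDIA#6499). The MLP previously built its Linear projections without a quant_config, so an FP8 checkpoint's quantization recipe never reached them, which is presumably why generation quality was poor. A minimal sketch of the fixed construction path, assuming only the ModelConfig/Linear interfaces visible in the diff (pretrained_config, get_quant_config(), quant_config=...); the helper name is hypothetical and this is illustrative, not the module itself.

# Hypothetical helper showing how quant_config now reaches every projection.
def build_gemma3_mlp_projections(model_config):
    cfg = model_config.pretrained_config            # Gemma3TextConfig
    quant_config = model_config.get_quant_config()  # e.g. FP8 weights + FP8 KV cache

    def make_proj(in_features, out_features):
        # each Linear is constructed with the checkpoint's quantization recipe
        return Linear(in_features, out_features, bias=False,
                      dtype=cfg.torch_dtype, quant_config=quant_config)

    return (make_proj(cfg.hidden_size, cfg.intermediate_size),   # gate_proj
            make_proj(cfg.hidden_size, cfg.intermediate_size),   # up_proj
            make_proj(cfg.intermediate_size, cfg.hidden_size))   # down_proj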

tensorrt_llm/_utils.py

Lines changed: 1 addition & 0 deletions
@@ -991,6 +991,7 @@ def to_json_str(cls, event):
         return {
             "event_id": event.event_id,
             "data": event_serialize_func(event.data),
+            "window_size": event.window_size
         }
 
     @staticmethod
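
Note (not part of the commit): this one-line change is the KV-event fix (NVIDIA#6526). The event's window_size is now included when it is serialized to JSON, so consumers of the event stream can tell which attention-window size an event refers to. A hedged sketch of the resulting payload shape; every value below is a placeholder, and the data field stands in for whatever event_serialize_func(event.data) actually produces.

import json

# Illustrative only: field values are made up; the schema of "data" depends on the event type.
serialized_event = {
    "event_id": 42,
    "data": {"num_blocks": 1},   # placeholder for event_serialize_func(event.data)
    "window_size": 512,          # newly serialized alongside the existing fields
}
print(json.dumps(serialized_event))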

tests/integration/defs/accuracy/references/cnn_dailymail.yaml

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,10 @@
 google/gemma-3-1b-it:
 - accuracy: 22.988
+- quant_algo: FP8
+  kv_cache_quant_algo: FP8
+  accuracy: 22.988
+google/gemma-3-27b-it:
+- accuracy: 28.90
 gpt2:
 - accuracy: 18.408
 - quant_algo: W8A16
