@@ -214,6 +214,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
214214 }
215215 xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
216216 }
217+ else if (mKVCacheQuantMode.hasFp4KvCache())
218+ {
219+ xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
220+ }
217221 else
218222 {
219223 xqaParams.kv_cache_data_type = xqaParams.data_type ;
@@ -959,6 +963,9 @@ int AttentionOp::mlaGeneration(
959963 generation_params.can_use_one_more_block , generation_params.host_primary_pool_pointer ,
960964 generation_params.host_secondary_pool_pointer , generation_params.block_offsets );
961965
966+ // Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided.
967+ auto kv_scale_cache_buffer = KVBlockArray ();
968+
962969 // Workspace pointer shift
963970 int8_t * workspace_byte_ptr = reinterpret_cast <int8_t *>(params.workspace );
964971 size_t offset = 0 ;
@@ -1234,7 +1241,7 @@ int AttentionOp::mlaGeneration(
12341241 {
12351242 TLLM_LOG_DEBUG (" XQA kernels are selected in the generation phase." );
12361243 xqaParams.stream = stream;
1237- mXqaDispatcher ->run (xqaParams, kv_cache_buffer);
1244+ mXqaDispatcher ->run (xqaParams, kv_cache_buffer, kv_scale_cache_buffer );
12381245 return 0 ;
12391246 }
12401247 else if (mIsSpecDecodingEnabled && mUseSpecDecoding )
@@ -1308,8 +1315,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
13081315 float const q_scaling = mQScaling ;
13091316
13101317 KVCacheBuffer kv_cache_buffer;
1311- auto const elemSize = mKVCacheQuantMode .hasKvCacheQuant () ? sizeof (int8_t ) : sizeof (T);
1312- auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
1318+ KVCacheBuffer kv_scale_cache_buffer;
1319+
1320+ auto sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /* bits*/ ;
1321+
13131322 if (useKVCache ())
13141323 {
13151324 if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1318,6 +1327,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
13181327 sizePerToken, params.cyclic_attention_window_size , params.max_cyclic_attention_window_size ,
13191328 params.sink_token_length , params.can_use_one_more_block , params.host_primary_pool_pointer ,
13201329 params.host_secondary_pool_pointer , params.block_offsets );
1330+ if (mKVCacheQuantMode .hasFp4KvCache ())
1331+ {
1332+ kv_scale_cache_buffer = KVBlockArray (params.batch_size , params.max_blocks_per_sequence , mTokensPerBlock ,
1333+ sizePerToken / 8 , params.cyclic_attention_window_size , params.max_cyclic_attention_window_size ,
1334+ params.sink_token_length , params.can_use_one_more_block ,
1335+ params.host_primary_block_scale_pool_pointer , params.host_secondary_block_scale_pool_pointer ,
1336+ params.block_offsets );
1337+ }
13211338 }
13221339 else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
13231340 {
@@ -1326,6 +1343,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
13261343 isCrossAttention () ? params.cross_kv_length : params.max_attention_window_size , sizePerToken,
13271344 params.cyclic_attention_window_size , params.sink_token_length , false ,
13281345 reinterpret_cast <BufferDataType*>(params.key_value_cache ));
1346+ TLLM_CHECK_WITH_INFO (!(mKVCacheQuantMode .hasFp4KvCache ()), " FP4 KV cache only supports paged KV." );
13291347 }
13301348 }
13311349
@@ -1490,8 +1508,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
14901508 decoder_params.blockSparseParams = mBlockSparseParams ;
14911509 decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
14921510 decoder_params.quantScaleO = params.attention_output_orig_quant ;
1493- decoder_params.dequantScaleQ = params.kv_scale_quant_orig ;
1494- decoder_params.dequantScaleKv = params.kv_scale_quant_orig;
1511+ decoder_params.dequantScaleQkv = params.kv_scale_quant_orig ;
1512+ decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
14951513 decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
14961514 decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
14971515 decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;
@@ -1549,9 +1567,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
15491567 sync_check_cuda_error (stream);
15501568 }
15511569
1552- KvCacheDataType const cache_type = mKVCacheQuantMode .hasInt8KvCache ()
1553- ? KvCacheDataType::INT8
1554- : (mKVCacheQuantMode .hasFp8KvCache () ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
1570+ KvCacheDataType cache_type{KvCacheDataType::BASE};
1571+ if (mKVCacheQuantMode .hasInt8KvCache ())
1572+ {
1573+ cache_type = KvCacheDataType::INT8;
1574+ }
1575+ else if (mKVCacheQuantMode .hasFp8KvCache ())
1576+ {
1577+ cache_type = KvCacheDataType::FP8;
1578+ }
1579+ else if (mKVCacheQuantMode .hasFp4KvCache ())
1580+ {
1581+ cache_type = KvCacheDataType::NVFP4;
1582+ }
15551583
15561584 cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
15571585 int const attention_seq_len_1 = params.input_seq_length ; // q length
@@ -1600,6 +1628,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
16001628 preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
16011629 preprocessingParams.q_output = q_buf_2_;
16021630 preprocessingParams.kv_cache_buffer = kv_cache_buffer;
1631+ preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
16031632 preprocessingParams.qkv_bias = params.qkv_bias ;
16041633 preprocessingParams.tokens_info = decoder_params.tokensInfo ;
16051634 preprocessingParams.seq_lens = params.context_lengths ;
@@ -1612,7 +1641,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
16121641 preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
16131642 preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin ;
16141643 preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin ;
1615- preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant ;
1644+ preprocessingParams.qkv_scale_orig_quant = params.kv_scale_orig_quant ;
16161645 preprocessingParams.spec_decoding_position_offsets = nullptr ;
16171646 preprocessingParams.logn_scaling = params.logn_scaling_ptr ;
16181647
@@ -1781,6 +1810,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
17811810 if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
17821811 {
17831812 fmhaParams.pagedKvCache = kv_cache_buffer;
1813+ fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
17841814 }
17851815 fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
17861816 fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths ;
@@ -2126,8 +2156,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
21262156 int32_t const batch_beam = params.beam_width * params.num_requests ;
21272157
21282158 KVCacheBuffer kv_cache_buffer;
2129- auto const elemSize = mKVCacheQuantMode .hasKvCacheQuant () ? sizeof (int8_t ) : sizeof (T);
2130- auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
2159+ KVCacheBuffer kv_scale_cache_buffer;
2160+
2161+ auto const sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /* bits*/ ;
2162+
21312163 if (useKVCache ())
21322164 {
21332165 if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -2137,13 +2169,22 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
21372169 params.cyclic_attention_window_size , params.max_cyclic_attention_window_size , params.sink_token_length ,
21382170 params.can_use_one_more_block , params.host_primary_pool_pointer , params.host_secondary_pool_pointer ,
21392171 reinterpret_cast <BufferDataType*>(params.block_offsets ));
2172+ if (mKVCacheQuantMode .hasFp4KvCache ())
2173+ {
2174+ kv_scale_cache_buffer = KVBlockArray (batch_beam, params.max_blocks_per_sequence , mTokensPerBlock ,
2175+ sizePerToken / 8 , params.cyclic_attention_window_size , params.max_cyclic_attention_window_size ,
2176+ params.sink_token_length , params.can_use_one_more_block ,
2177+ params.host_primary_block_scale_pool_pointer , params.host_secondary_block_scale_pool_pointer ,
2178+ reinterpret_cast <BufferDataType*>(params.block_offsets ));
2179+ }
21402180 }
21412181 else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
21422182 {
21432183 using BufferDataType = typename KVCacheBuffer::DataType;
21442184 kv_cache_buffer = KVLinearBuffer (batch_beam, params.max_attention_window_size , sizePerToken,
21452185 params.cyclic_attention_window_size , params.sink_token_length , false ,
21462186 reinterpret_cast <BufferDataType*>(params.key_value_cache ));
2187+ TLLM_CHECK_WITH_INFO (!(mKVCacheQuantMode .hasFp4KvCache ()), " FP4 KV cache only supports paged KV." );
21472188 }
21482189 }
21492190 sync_check_cuda_error (stream);
@@ -2215,7 +2256,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
22152256 xqaParams.output = mhaOutput;
22162257 xqaParams.qkv = attention_input;
22172258 }
2218- mXqaDispatcher ->run (xqaParams, kv_cache_buffer);
2259+ mXqaDispatcher ->run (xqaParams, kv_cache_buffer, kv_scale_cache_buffer );
22192260 if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1 )
22202261 {
22212262 this ->template ulyssesGenerationPostprocess <T>(
@@ -2232,6 +2273,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
22322273 {
22332274 TLLM_CHECK_WITH_INFO (false , " No available kernels are found for FP4 output." );
22342275 }
2276+ else if (mKVCacheQuantMode .hasFp4KvCache ())
2277+ {
2278+ TLLM_CHECK_WITH_INFO (false , " No available kernels are found for FP4 KV cache." );
2279+ }
22352280 else
22362281 {
22372282 TLLM_LOG_DEBUG (" XQA kernels are not selected in the generation phase." );
@@ -2503,6 +2548,10 @@ int AttentionOp::initialize() noexcept
25032548 TLLM_CHECK_WITH_INFO (!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121 ,
25042549 " fuse_fp4_quant only supports SM100 or SM120 or SM121 devices." );
25052550
2551+ // Check requirements for FP4 KV cache.
2552+ TLLM_CHECK_WITH_INFO (!mKVCacheQuantMode .hasFp4KvCache () || mFP8ContextFMHA ,
2553+ " mFP8ContextFMHA must enable if FP4 KV cache is enabled" );
2554+
25062555 TLLM_CHECK (isRoPE () == (mRotaryEmbeddingDim != 0 ));
25072556 TLLM_CHECK_WITH_INFO ((mSM >= 80 ) || (mType != nvinfer1::DataType::kBF16 ),
25082557 " Unsupported data type, pre SM 80 GPUs do not support bfloat16" );
@@ -2579,7 +2628,10 @@ int AttentionOp::initialize() noexcept
25792628 {
25802629 fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
25812630 }
2582- // TODO: add FP4 KV cache support.
2631+ else if (mKVCacheQuantMode .hasFp4KvCache ())
2632+ {
2633+ fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
2634+ }
25832635 }
25842636 // The output dtype.
25852637 fmhaParams.dataTypeOut = data_type;
@@ -2789,6 +2841,11 @@ int AttentionOp::initialize() noexcept
27892841 fixedParams.kvDataType = DATA_TYPE_E4M3;
27902842 fixedParams.mathDataType = DATA_TYPE_E4M3;
27912843 }
2844+ else if (mKVCacheQuantMode .hasFp4KvCache ())
2845+ {
2846+ fixedParams.kvDataType = DATA_TYPE_E2M1;
2847+ fixedParams.mathDataType = DATA_TYPE_E4M3;
2848+ }
27922849 else
27932850 {
27942851 fixedParams.kvDataType = fixedParams.inputDataType ;
0 commit comments