diff --git a/cpp/tensorrt_llm/kernels/quantization.cu b/cpp/tensorrt_llm/kernels/quantization.cu
index 817b0a57ee9..78248214c19 100644
--- a/cpp/tensorrt_llm/kernels/quantization.cu
+++ b/cpp/tensorrt_llm/kernels/quantization.cu
@@ -302,6 +302,98 @@ void invokeBlockScaleInterleaveReverse(
     block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
 }
 
+template <typename T>
+struct VecTypeImpl
+{
+    using type = T;
+};
+
+template <>
+struct VecTypeImpl<half>
+{
+    using type = half2;
+};
+
+template <>
+struct VecTypeImpl<__nv_bfloat16>
+{
+    using type = __nv_bfloat162;
+};
+
+template <typename T>
+using VecType = typename VecTypeImpl<T>::type;
+
+template <typename T>
+__device__ float getMaxAbs(float4& vec)
+{
+    auto absMaxVec = cuda_abs(reinterpret_cast<VecType<T>*>(&vec)[0]);
+    for (int i = 1; i < 4; ++i)
+    {
+        absMaxVec = cuda_max(absMaxVec, cuda_abs(reinterpret_cast<VecType<T>*>(&vec)[i]));
+    }
+    float absMaxVal;
+    if constexpr (sizeof(T) == 4)
+    {
+        absMaxVal = static_cast<float>(absMaxVec);
+    }
+    else
+    {
+        absMaxVal = static_cast<float>(cuda_max(absMaxVec.x, absMaxVec.y));
+    }
+    tensorrt_llm::common::blockReduceMaxV2<float, 1>(&absMaxVal);
+    return absMaxVal;
+}
+
+template <typename T>
+__global__ void computePerTokenGlobalScaleForFP4QuantizationKernel(
+    int b, int m, int n, T const* input, int const* tokensPerBatch, float* globalScale)
+{
+    static constexpr int ElemsPerVec = 16 / sizeof(T);
+    int batchIdx = blockIdx.x;
+    int realTokensNum = (tokensPerBatch == nullptr) ? m : tokensPerBatch[batchIdx];
+    input += batchIdx * m * n;
+    globalScale += batchIdx * m;
+    for (int tokenIdx = blockIdx.y; tokenIdx < realTokensNum; tokenIdx += gridDim.y)
+    {
+        float perTokenMaxAbsVal = 0.f;
+        for (int vecIdx = threadIdx.x; vecIdx < n / ElemsPerVec; vecIdx += blockDim.x)
+        {
+            float4 vec = reinterpret_cast<float4 const*>(input + tokenIdx * n)[vecIdx];
+            float maxAbsVal = getMaxAbs<T>(vec);
+            perTokenMaxAbsVal = cuda_max(perTokenMaxAbsVal, maxAbsVal);
+        }
+        float globalScaleVal = 448.f * 6.f / perTokenMaxAbsVal;
+        if (threadIdx.x == 0)
+        {
+            globalScale[tokenIdx] = globalScaleVal;
+        }
+    }
+}
+
+template <typename T>
+void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
+    float* globalScale, int multiProcessorCount, cudaStream_t stream)
+{
+
+    static constexpr int ElemsPerVec = 16 / sizeof(T);
+    TLLM_CHECK(n % (ElemsPerVec * 32) == 0 and b > 0);
+    dim3 block(std::min(n / ElemsPerVec, 1024));
+    dim3 grid(b, std::max(1, std::min(m, multiProcessorCount / b)));
+
+    cudaLaunchConfig_t config;
+    config.gridDim = grid;
+    config.blockDim = block;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    TLLM_CUDA_CHECK(cudaLaunchKernelEx(
+        &config, computePerTokenGlobalScaleForFP4QuantizationKernel<T>, b, m, n, input, tokensPerBatch, globalScale));
+}
+
 // Instantiate the function.
 template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input, float const* SFScale,
     int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
@@ -311,6 +403,8 @@ template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* i
     cudaStream_t stream);
 template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input, int64_t* output,
     int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
+template void computePerTokenGlobalScaleForFP4Quantization<half>(int b, int m, int n, half const* input,
+    int const* tokensPerBatch, float* globalScale, int multiProcessorCount, cudaStream_t stream);
 #ifdef ENABLE_BF16
 template void invokeFP4Quantization<__nv_bfloat16, 16>(int b, int m, int n, __nv_bfloat16 const* input,
     float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout,
@@ -320,6 +414,9 @@ template void invokeFP4Quantization<__nv_bfloat16, 32>(int b, int m, int n, __nv
     int multiProcessorCount, cudaStream_t stream);
 template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, __nv_bfloat16 const* input,
     int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
+template void computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(int b, int m, int n,
+    __nv_bfloat16 const* input, int const* tokensPerBatch, float* globalScale, int multiProcessorCount,
+    cudaStream_t stream);
 #endif
 
 #ifdef ENABLE_FP8
diff --git a/cpp/tensorrt_llm/kernels/quantization.h b/cpp/tensorrt_llm/kernels/quantization.h
index 160a54428a7..70776b27906 100644
--- a/cpp/tensorrt_llm/kernels/quantization.h
+++ b/cpp/tensorrt_llm/kernels/quantization.h
@@ -88,5 +88,9 @@ void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
 void invokeBlockScaleInterleaveReverse(
     int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
 
+template <typename T>
+void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
+    float* globalScale, int multiProcessorCount, cudaStream_t stream = 0);
+
 } // namespace kernels
 } // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/thop/fp4Quantize.cpp b/cpp/tensorrt_llm/thop/fp4Quantize.cpp
index 956b7523249..7fb66047ea2 100644
--- a/cpp/tensorrt_llm/thop/fp4Quantize.cpp
+++ b/cpp/tensorrt_llm/thop/fp4Quantize.cpp
@@ -153,6 +153,83 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::opt
 
     return {valueE2M1, scaleFP8SF};
 }
+
+at::Tensor calculate_nvfp4_global_scale(at::Tensor const& input, std::optional<at::Tensor> const& tokensPerBatch)
+{
+    CHECK_TH_CUDA(input);
+    CHECK_CONTIGUOUS(input);
+
+    auto const& inputShape = input.sizes();
+    auto const& rank = inputShape.size();
+
+    TORCH_CHECK(rank >= 2 && rank <= 3);
+
+    // Calculate batch and token numbers
+    int64_t batch_size = 1;
+    int64_t token_num = 1;
+    int64_t hidden_size = inputShape[rank - 1];
+
+    if (rank == 2)
+    {
+        // [token_num, hidden_size]
+        token_num = inputShape[0];
+        batch_size = 1;
+    }
+    else if (rank == 3)
+    {
+        // [batch, token_num, hidden_size]
+        batch_size = inputShape[0];
+        token_num = inputShape[1];
+    }
+
+    // Create output tensor with same dimensions as input, but last dimension size is 1
+    std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
+    outputShape[rank - 1] = 1;
+
+    at::Tensor globalScale = at::detail::empty_cuda(outputShape, torch::kFloat32, input.device(), std::nullopt);
+
+    // Get multi-processor count
+    static int multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
+    // Prepare tokensPerBatch pointer - should have shape (batch_size)
+    int const* tokensPerBatchPtr = nullptr;
+    if (tokensPerBatch.has_value())
+    {
+        CHECK_TH_CUDA(tokensPerBatch.value());
+        CHECK_CONTIGUOUS(tokensPerBatch.value());
+
+        auto const& tokensShape = tokensPerBatch.value().sizes();
+        TORCH_CHECK(tokensShape.size() == 1, "tokensPerBatch should have exactly 1 dimension");
+        TORCH_CHECK(tokensShape[0] == batch_size, "tokensPerBatch first dimension must match input batch size");
+
+        tokensPerBatchPtr = tokensPerBatch.value().data_ptr<int>();
+    }
+
+    // Call corresponding kernel based on input data type
+    if (input.scalar_type() == at::ScalarType::Half)
+    {
+        tensorrt_llm::kernels::computePerTokenGlobalScaleForFP4Quantization(batch_size, token_num, hidden_size,
+            reinterpret_cast<half const*>(input.data_ptr()), tokensPerBatchPtr, globalScale.data_ptr<float>(),
+            multiProcessorCount, at::cuda::getCurrentCUDAStream(input.get_device()));
+    }
+    else if (input.scalar_type() == at::ScalarType::BFloat16)
+    {
+#ifdef ENABLE_BF16
+        tensorrt_llm::kernels::computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(batch_size, token_num,
+            hidden_size, reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()), tokensPerBatchPtr,
+            globalScale.data_ptr<float>(), multiProcessorCount, at::cuda::getCurrentCUDAStream(input.get_device()));
+#else
+        C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to compute global scale for bf16 tensor.");
+#endif
+    }
+    else
+    {
+        C10_THROW_ERROR(
+            NotImplementedError, "calculate_nvfp4_global_scale only supports input tensor with dtypes fp16/bf16.");
+    }
+
+    return globalScale;
+}
 } // namespace torch_ext
 
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
@@ -161,9 +238,11 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "fp4_quantize(Tensor input, Tensor? globalScale, int sfVecSize, bool sfUseUE8M0, bool "
        "isSfSwizzledLayout=True) "
         "-> (Tensor, Tensor)");
+    m.def("calculate_nvfp4_global_scale(Tensor input, Tensor? tokensPerBatch) -> Tensor");
tokensPerBatch) -> Tensor"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("fp4_quantize", TORCH_FN(torch_ext::fp4_quantize)); + m.impl("calculate_nvfp4_global_scale", TORCH_FN(torch_ext::calculate_nvfp4_global_scale)); } diff --git a/cpp/tensorrt_llm/thop/fp4Quantize.h b/cpp/tensorrt_llm/thop/fp4Quantize.h index e460cb9c95a..0d3b36f9c22 100644 --- a/cpp/tensorrt_llm/thop/fp4Quantize.h +++ b/cpp/tensorrt_llm/thop/fp4Quantize.h @@ -26,4 +26,6 @@ namespace torch_ext { std::tuple fp4_quantize(at::Tensor const& self, std::optional const& globalScale, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout); -} + +at::Tensor calculate_nvfp4_global_scale(at::Tensor const& input, std::optional const& tokensPerBatch); +} // namespace torch_ext diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 0d75cf636c3..738486b653b 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -179,6 +179,10 @@ def _( return (input.new_empty(output_shape, dtype=torch.uint8), global_scale.new_empty(scale_shape, dtype=torch.uint8)) + @torch.library.register_fake("trtllm::calculate_nvfp4_global_scale") + def _(input: torch.Tensor, tokens_per_batch: Optional[torch.Tensor]): + return input.new_empty((input.shape[:-1], 1), dtype=torch.float32) + @torch.library.register_fake("trtllm::moe_comm") def _( inputs: List[torch.Tensor], diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 22d14a83b55..a3e232bb730 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -691,8 +691,8 @@ def forward_chunk( self.expert_size_per_partition, num_tokens_per_expert_for_fused_moe, self.hidden_size) if self.use_low_precision_combine: - global_scales = (448 * 6) / final_hidden_states.abs().max( - dim=-1, keepdim=True).values.to(torch.float32) + global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale( + final_hidden_states, recv_expert_count) final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4( final_hidden_states, global_scales, deep_ep_topk_idx, deep_ep_topk_weights, deep_ep_handle) diff --git a/tests/unittest/_torch/thop/test_fp4_calculate_global_scale.py b/tests/unittest/_torch/thop/test_fp4_calculate_global_scale.py new file mode 100644 index 00000000000..87ce36c3378 --- /dev/null +++ b/tests/unittest/_torch/thop/test_fp4_calculate_global_scale.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import torch
+from parameterized import parameterized
+from utils.util import skip_pre_blackwell_unittest, unittest_name_func
+
+import tensorrt_llm
+
+
+def reference_calculate_global_scale(input_tensor):
+    max_abs_values = input_tensor.abs().max(dim=-1, keepdim=True).values.to(
+        torch.float32)
+    global_scales = (448 * 6) / max_abs_values
+    return global_scales
+
+
+class TestFP4CalculateGlobalScale(unittest.TestCase):
+
+    def setUp(self):
+        tensorrt_llm.logger.set_level("warning")
+        torch.manual_seed(42)
+        torch.cuda.manual_seed(42)
+
+    @parameterized.expand([
+        [1, 64, 7168, torch.bfloat16, False],
+        [1, 64, 7168, torch.float16, False],
+        [1, 64, 7168, torch.bfloat16, True],
+        [1, 64, 4096, torch.bfloat16, True],
+        [8, 8 * 64, 7168, torch.bfloat16, False],
+        [8, 8 * 64, 7168, torch.bfloat16, True],
+        [16, 16 * 64, 7168, torch.bfloat16, True],
+        [32, 32 * 64, 7168, torch.bfloat16, True],
+    ],
+                          name_func=unittest_name_func)
+    @skip_pre_blackwell_unittest
+    def test_calculate_nvfp4_global_scale_accuracy(self, batch_size,
+                                                   max_token_num, hidden_size,
+                                                   dtype, use_tokens_per_batch):
+        if batch_size == 1:
+            input_shape = (max_token_num, hidden_size)
+        else:
+            input_shape = (batch_size, max_token_num, hidden_size)
+        input_tensor = torch.randn(input_shape, dtype=dtype, device='cuda')
+
+        assert hidden_size % 16 == 0, f"Hidden size {hidden_size} must be divisible by 16"
+
+        tokens_per_batch = None
+        if use_tokens_per_batch:
+            # Create tokensPerBatch tensor with shape (batch_size)
+            # Each value represents the actual number of meaningful tokens in that batch
+            tokens_per_batch = torch.randint(0,
+                                             max_token_num + 1, (batch_size, ),
+                                             device='cuda',
+                                             dtype=torch.int32)
+
+        reference_result = reference_calculate_global_scale(input_tensor)
+        custom_result = torch.ops.trtllm.calculate_nvfp4_global_scale(
+            input_tensor, tokens_per_batch)
+        torch.cuda.synchronize()
+
+        self.assertEqual(custom_result.shape, reference_result.shape)
+
+        if use_tokens_per_batch:
+            # Create mask for meaningful tokens based on tokens_per_batch
+            # Only compare results for tokens that are within the meaningful range
+            meaningful_mask = torch.zeros(custom_result.shape,
+                                          dtype=torch.bool,
+                                          device='cuda')
+            if batch_size == 1:
+                meaningful_mask[:tokens_per_batch[0]] = True
+            else:
+                for i in range(batch_size):
+                    meaningful_mask[i, :tokens_per_batch[i]] = True
+
+            custom_result = custom_result * meaningful_mask
+            reference_result = reference_result * meaningful_mask
+
+        torch.testing.assert_close(
+            custom_result,
+            reference_result,
+            atol=1e-3,
+            rtol=1e-3,
+            msg=
+            f"Shape: {input_shape}, dtype: {dtype}, custom_result: {custom_result}, reference_result: {reference_result}"
+        )
+
+    @parameterized.expand(
+        [
+            # [local_experts_num, ranks_num * max_token_num_per_rank, max_token_num_per_rank, hidden_size]
+            [32, 8 * 64, 64, 7168, torch.bfloat16, False],
+            [32, 8 * 64, 64, 7168, torch.bfloat16, True],
+            [16, 16 * 64, 64, 7168, torch.bfloat16, False],
+            [16, 16 * 64, 64, 7168, torch.bfloat16, True],
+            [8, 32 * 64, 64, 7168, torch.bfloat16, False],
+            [8, 32 * 64, 64, 7168, torch.bfloat16, True],
+        ],
+        name_func=unittest_name_func)
+    @skip_pre_blackwell_unittest
+    def test_calculate_nvfp4_global_scale_performance(self, batch_size,
+                                                      max_token_num,
+                                                      real_token_num,
+                                                      hidden_size, dtype,
+                                                      use_tokens_per_batch):
+        if batch_size == 1:
+            input_shape = (max_token_num, hidden_size)
+        else:
+            input_shape = (batch_size, max_token_num, hidden_size)
+        input_tensor = torch.randn(input_shape, dtype=dtype, device='cuda')
+
+        tokens_per_batch = None
+        if use_tokens_per_batch:
+            tokens_per_batch = torch.zeros((batch_size, ),
+                                           device='cuda',
+                                           dtype=torch.int32)
+            tokens_per_batch[:] = real_token_num
+
+        for _ in range(10):
+            _ = torch.ops.trtllm.calculate_nvfp4_global_scale(
+                input_tensor, tokens_per_batch)
+            _ = reference_calculate_global_scale(input_tensor)
+
+        torch.cuda.synchronize()
+
+        num_iterations = 100
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+
+        start_event.record()
+        for _ in range(num_iterations):
+            _ = torch.ops.trtllm.calculate_nvfp4_global_scale(
+                input_tensor, tokens_per_batch)
+        end_event.record()
+        torch.cuda.synchronize()
+        custom_time = start_event.elapsed_time(end_event)
+
+        start_event.record()
+        for _ in range(num_iterations):
+            _ = reference_calculate_global_scale(input_tensor)
+        end_event.record()
+        torch.cuda.synchronize()
+        reference_time = start_event.elapsed_time(end_event)
+
+        custom_avg_time = custom_time / num_iterations
+        reference_avg_time = reference_time / num_iterations
+        speedup = reference_avg_time / custom_avg_time
+
+        tokens_info = "with tokensPerBatch" if use_tokens_per_batch else "without tokensPerBatch"
+        print(
+            f"\nPerformance Test Results for {input_shape}, {real_token_num}, {dtype}, {tokens_info}:"
+        )
+        print(f"Custom op average time: {custom_avg_time*1000:.3f} us")
+        print(f"Reference average time: {reference_avg_time*1000:.3f} us")
+        print(f"Speedup: {speedup:.2f}x")
+
+    @skip_pre_blackwell_unittest
+    def test_calculate_nvfp4_global_scale_invalid_inputs(self):
+        # Test with 1D tensor (should fail)
+        input_tensor = torch.randn(4096, dtype=torch.float16, device='cuda')
+        with self.assertRaises(Exception):
+            torch.ops.trtllm.calculate_nvfp4_global_scale(input_tensor)
+
+        # Test with hidden_size not divisible by 16 (should fail)
+        input_tensor = torch.randn((4, 32, 4095),
+                                   dtype=torch.float16,
+                                   device='cuda')
+        with self.assertRaises(Exception):
+            torch.ops.trtllm.calculate_nvfp4_global_scale(input_tensor)
+
+        # Test with mismatched tokensPerBatch shape (wrong batch size)
+        input_tensor = torch.randn((4, 32, 4096),
+                                   dtype=torch.float16,
+                                   device='cuda')
+        tokens_per_batch = torch.randint(1,
+                                         33, (5, ),
+                                         device='cuda',
+                                         dtype=torch.int32)  # Wrong batch size
+        with self.assertRaises(Exception):
+            torch.ops.trtllm.calculate_nvfp4_global_scale(
+                input_tensor, tokens_per_batch)
+
+        # Test with tokensPerBatch having wrong number of dimensions (should be 1D)
+        input_tensor = torch.randn((4, 32, 4096),
+                                   dtype=torch.float16,
+                                   device='cuda')
+        tokens_per_batch = torch.randint(1,
+                                         33, (4, 32),
+                                         device='cuda',
+                                         dtype=torch.int32)  # 2D instead of 1D
+        with self.assertRaises(Exception):
+            torch.ops.trtllm.calculate_nvfp4_global_scale(
+                input_tensor, tokens_per_batch)
+
+        # Test with tokensPerBatch having wrong first dimension size
+        input_tensor = torch.randn((4, 32, 4096),
+                                   dtype=torch.float16,
+                                   device='cuda')
+        tokens_per_batch = torch.randint(1,
+                                         33, (3, ),
+                                         device='cuda',
+                                         dtype=torch.int32)  # Wrong batch size
+        with self.assertRaises(Exception):
+            torch.ops.trtllm.calculate_nvfp4_global_scale(
+                input_tensor, tokens_per_batch)
+
+
+if __name__ == '__main__':
+    unittest.main()
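
For context, a minimal usage sketch of the new torch.ops.trtllm.calculate_nvfp4_global_scale op. It only mirrors the eager reference formula used by the unit test above; the shapes and tolerances are illustrative, and it assumes the trtllm extension is built with this change and runs on a GPU supported by the kernel (the test skips pre-Blackwell devices).

import torch

import tensorrt_llm  # noqa: F401  (importing tensorrt_llm loads the trtllm custom ops)

# [batch, max_token_num, hidden]; only the first tokens_per_batch[i] rows of batch i are meaningful.
x = torch.randn(8, 64, 7168, dtype=torch.bfloat16, device='cuda')
tokens_per_batch = torch.full((8, ), 48, dtype=torch.int32, device='cuda')

# Fused per-token global scale: (448 * 6) / max(|x|) over the hidden dimension, shape [8, 64, 1].
scales = torch.ops.trtllm.calculate_nvfp4_global_scale(x, tokens_per_batch)

# Eager equivalent (same formula as reference_calculate_global_scale in the test).
# Rows beyond tokens_per_batch[i] are left unwritten by the op, so compare only the valid ones.
ref = (448 * 6) / x.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
torch.testing.assert_close(scales[:, :48], ref[:, :48], atol=1e-3, rtol=1e-3)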