
Commit 0f0f0e5

custom kernel for calculating nvfp4 global scale
Signed-off-by: Yilin Zhang <[email protected]>
1 parent 54ec2c1 commit 0f0f0e5

6 files changed, +411 -3 lines

cpp/tensorrt_llm/kernels/quantization.cu

Lines changed: 97 additions & 0 deletions
@@ -302,6 +302,98 @@ void invokeBlockScaleInterleaveReverse(
     block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
 }
 
+template <typename T>
+struct VecTypeImpl
+{
+    using type = T;
+};
+
+template <>
+struct VecTypeImpl<half>
+{
+    using type = half2;
+};
+
+template <>
+struct VecTypeImpl<__nv_bfloat16>
+{
+    using type = __nv_bfloat162;
+};
+
+template <typename T>
+using VecType = typename VecTypeImpl<T>::type;
+
+template <typename T>
+__device__ float getMaxAbs(float4& vec)
+{
+    auto absMaxVec = cuda_abs(reinterpret_cast<VecType<T>*>(&vec)[0]);
+    for (int i = 1; i < 4; ++i)
+    {
+        absMaxVec = cuda_max(absMaxVec, cuda_abs(reinterpret_cast<VecType<T>*>(&vec)[i]));
+    }
+    float absMaxVal;
+    if constexpr (sizeof(T) == 4)
+    {
+        absMaxVal = static_cast<float>(absMaxVec);
+    }
+    else
+    {
+        absMaxVal = static_cast<float>(cuda_max(absMaxVec.x, absMaxVec.y));
+    }
+    tensorrt_llm::common::blockReduceMaxV2<float, 1>(&absMaxVal);
+    return absMaxVal;
+}
+
+template <typename T>
+__global__ void computePerTokenGlobalScaleForFP4QuantizationKernel(
+    int b, int m, int n, T const* input, int const* tokensPerBatch, float* globalScale)
+{
+    static constexpr int ElemsPerVec = 16 / sizeof(T);
+    int batchIdx = blockIdx.x;
+    int realTokensNum = (tokensPerBatch == nullptr) ? m : tokensPerBatch[batchIdx];
+    input += batchIdx * m * n;
+    globalScale += batchIdx * m;
+    for (int tokenIdx = blockIdx.y; tokenIdx < realTokensNum; tokenIdx += gridDim.y)
+    {
+        float perTokenMaxAbsVal = 0.f;
+        for (int vecIdx = threadIdx.x; vecIdx < n / ElemsPerVec; vecIdx += blockDim.x)
+        {
+            float4 vec = reinterpret_cast<float4 const*>(input + tokenIdx * n)[vecIdx];
+            float maxAbsVal = getMaxAbs<T>(vec);
+            perTokenMaxAbsVal = cuda_max(perTokenMaxAbsVal, maxAbsVal);
+        }
+        float globalScaleVal = 448.f * 6.f / perTokenMaxAbsVal;
+        if (threadIdx.x == 0)
+        {
+            globalScale[tokenIdx] = globalScaleVal;
+        }
+    }
+}
+
+template <typename T>
+void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
+    float* globalScale, int multiProcessorCount, cudaStream_t stream)
+{
+
+    static constexpr int ElemsPerVec = 16 / sizeof(T);
+    TLLM_CHECK(n % (ElemsPerVec * 32) == 0 and b > 0);
+    dim3 block(std::min(n / ElemsPerVec, 1024));
+    dim3 grid(b, std::max(1, std::min(m, multiProcessorCount / b)));
+
+    cudaLaunchConfig_t config;
+    config.gridDim = grid;
+    config.blockDim = block;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    TLLM_CUDA_CHECK(cudaLaunchKernelEx(
+        &config, computePerTokenGlobalScaleForFP4QuantizationKernel<T>, b, m, n, input, tokensPerBatch, globalScale));
+}
+
 // Instantiate the function.
 template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input, float const* SFScale,
     int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
@@ -311,6 +403,8 @@ template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* i
     cudaStream_t stream);
 template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input, int64_t* output,
     int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
+template void computePerTokenGlobalScaleForFP4Quantization<half>(int b, int m, int n, half const* input,
+    int const* tokensPerBatch, float* globalScale, int multiProcessorCount, cudaStream_t stream);
 #ifdef ENABLE_BF16
 template void invokeFP4Quantization<__nv_bfloat16, 16>(int b, int m, int n, __nv_bfloat16 const* input,
     float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout,
@@ -320,6 +414,9 @@ template void invokeFP4Quantization<__nv_bfloat16, 32>(int b, int m, int n, __nv
     int multiProcessorCount, cudaStream_t stream);
 template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, __nv_bfloat16 const* input,
     int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
+template void computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(int b, int m, int n,
+    __nv_bfloat16 const* input, int const* tokensPerBatch, float* globalScale, int multiProcessorCount,
+    cudaStream_t stream);
 #endif
 
 #ifdef ENABLE_FP8
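The math the new kernel implements is compact: for every token (a row of length n), take the maximum absolute value over the hidden dimension and return 448 * 6 divided by it (the FP8 E4M3 maximum times the FP4 E2M1 maximum over the per-token amax). Below is a minimal PyTorch sketch of that computation; it is not part of the commit, the function name is illustrative, and it ignores the tokensPerBatch masking the kernel supports.

import torch

def per_token_nvfp4_global_scale_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., token_num, hidden_size] in fp16/bf16.
    # Mirrors `448.f * 6.f / perTokenMaxAbsVal` from the kernel above.
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    return (448.0 * 6.0) / amax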

cpp/tensorrt_llm/kernels/quantization.h

Lines changed: 4 additions & 0 deletions
@@ -88,5 +88,9 @@ void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
 void invokeBlockScaleInterleaveReverse(
     int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
 
+template <typename T>
+void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
+    float* globalScale, int multiProcessorCount, cudaStream_t stream = 0);
+
 } // namespace kernels
 } // namespace tensorrt_llm

cpp/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 82 additions & 0 deletions
@@ -153,6 +153,86 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::opt
 
     return {valueE2M1, scaleFP8SF};
 }
+
+at::Tensor calculate_nvfp4_global_scale(at::Tensor const& input, std::optional<at::Tensor> const& tokensPerBatch)
+{
+    CHECK_TH_CUDA(input);
+    CHECK_CONTIGUOUS(input);
+
+    auto const& inputShape = input.sizes();
+    auto const& rank = inputShape.size();
+
+    TORCH_CHECK(rank >= 2 && rank <= 3);
+
+    // Calculate batch and token numbers
+    int64_t batch_size = 1;
+    int64_t token_num = 1;
+    int64_t hidden_size = inputShape[rank - 1];
+
+    if (rank == 2)
+    {
+        // [token_num, hidden_size]
+        token_num = inputShape[0];
+        batch_size = 1;
+    }
+    else if (rank == 3)
+    {
+        // [batch, token_num, hidden_size]
+        batch_size = inputShape[0];
+        token_num = inputShape[1];
+    }
+
+    // Check if hidden_size is aligned
+    TORCH_CHECK(hidden_size % 16 == 0, "Hidden size must be divisible by 16 for FP4 quantization.");
+
+    // Create output tensor with same dimensions as input, but last dimension size is 1
+    std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
+    outputShape[rank - 1] = 1;
+
+    at::Tensor globalScale = at::detail::empty_cuda(outputShape, torch::kFloat32, input.device(), std::nullopt);
+
+    // Get multi-processor count
+    static int multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
+    // Prepare tokensPerBatch pointer - should have shape (batch_size)
+    int const* tokensPerBatchPtr = nullptr;
+    if (tokensPerBatch.has_value())
+    {
+        CHECK_TH_CUDA(tokensPerBatch.value());
+        CHECK_CONTIGUOUS(tokensPerBatch.value());
+
+        auto const& tokensShape = tokensPerBatch.value().sizes();
+        TORCH_CHECK(tokensShape.size() == 1, "tokensPerBatch should have exactly 1 dimension");
+        TORCH_CHECK(tokensShape[0] == batch_size, "tokensPerBatch first dimension must match input batch size");
+
+        tokensPerBatchPtr = tokensPerBatch.value().data_ptr<int>();
+    }
+
+    // Call corresponding kernel based on input data type
+    if (input.scalar_type() == at::ScalarType::Half)
+    {
+        tensorrt_llm::kernels::computePerTokenGlobalScaleForFP4Quantization<half>(batch_size, token_num, hidden_size,
+            reinterpret_cast<half const*>(input.data_ptr()), tokensPerBatchPtr, globalScale.data_ptr<float>(),
+            multiProcessorCount, at::cuda::getCurrentCUDAStream(input.get_device()));
+    }
+    else if (input.scalar_type() == at::ScalarType::BFloat16)
+    {
+#ifdef ENABLE_BF16
+        tensorrt_llm::kernels::computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(batch_size, token_num,
+            hidden_size, reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()), tokensPerBatchPtr,
+            globalScale.data_ptr<float>(), multiProcessorCount, at::cuda::getCurrentCUDAStream(input.get_device()));
+#else
+        C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to compute global scale for bf16 tensor.");
+#endif
+    }
+    else
+    {
+        C10_THROW_ERROR(
+            NotImplementedError, "calculate_nvfp4_global_scale only supports input tensor with dtypes fp16/bf16.");
+    }
+
+    return globalScale;
+}
 } // namespace torch_ext
 
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
@@ -161,9 +241,11 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         "fp4_quantize(Tensor input, Tensor? globalScale, int sfVecSize, bool sfUseUE8M0=False, bool "
         "isSfSwizzledLayout=True) "
         "-> (Tensor, Tensor)");
+    m.def("calculate_nvfp4_global_scale(Tensor input, Tensor? tokensPerBatch) -> Tensor");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
     m.impl("fp4_quantize", TORCH_FN(torch_ext::fp4_quantize));
+    m.impl("calculate_nvfp4_global_scale", TORCH_FN(torch_ext::calculate_nvfp4_global_scale));
 }
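A usage sketch of the newly registered op from Python; it assumes the TensorRT-LLM torch extension is built and imported so the trtllm ops are available, and the shapes here are arbitrary examples. When tokensPerBatch is given it must be a 1-D int32 CUDA tensor of length batch_size, and only the first tokensPerBatch[b] rows of each batch entry are written (the rest of the output stays uninitialized).

import torch
import tensorrt_llm  # assumption: importing the package registers torch.ops.trtllm

x = torch.randn(8, 256, 7168, dtype=torch.bfloat16, device="cuda")  # [batch, token_num, hidden_size]
tokens_per_batch = torch.full((8,), 200, dtype=torch.int32, device="cuda")

# One scale per token for every batch entry: output shape [8, 256, 1], float32.
scales_all = torch.ops.trtllm.calculate_nvfp4_global_scale(x, None)

# Scales only for the first tokens_per_batch[b] tokens of each batch entry.
scales_valid = torch.ops.trtllm.calculate_nvfp4_global_scale(x, tokens_per_batch)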

cpp/tensorrt_llm/thop/fp4Quantize.h

Lines changed: 3 additions & 1 deletion
@@ -26,4 +26,6 @@ namespace torch_ext
 {
 std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::optional<at::Tensor> const& globalScale,
     int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout);
-}
+
+at::Tensor calculate_nvfp4_global_scale(at::Tensor const& input, std::optional<at::Tensor> const& tokensPerBatch);
+} // namespace torch_ext

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 2 additions & 2 deletions
@@ -689,8 +689,8 @@ def forward_chunk(
                     self.expert_size_per_partition,
                     num_tokens_per_expert_for_fused_moe, self.hidden_size)
                 if self.use_low_precision_combine:
-                    global_scales = (448 * 6) / final_hidden_states.abs().max(
-                        dim=-1, keepdim=True).values.to(torch.float32)
+                    global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale(
+                        final_hidden_states, recv_expert_count)
                     final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4(
                         final_hidden_states, global_scales, deep_ep_topk_idx,
                         deep_ep_topk_weights, deep_ep_handle)
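In this wide-EP combine path the eager per-token amax expression is replaced by the fused op, with recv_expert_count passed as tokensPerBatch so scales are computed only for the tokens each expert actually received. For a plain 2-D tensor where every token is valid, the two forms should agree; a hedged sanity-check sketch (shapes and tolerances are illustrative, not from the commit):

import torch

h = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
old = (448 * 6) / h.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
new = torch.ops.trtllm.calculate_nvfp4_global_scale(h, None)
torch.testing.assert_close(old, new, rtol=1e-3, atol=1e-3)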
