@@ -302,6 +302,98 @@ void invokeBlockScaleInterleaveReverse(
302302 block_scale_interleave_reverse_kernel<<<grid, block, 0 , stream>>> (b, m, n, SFIn, SFOutput);
303303}
304304
// Type trait mapping an element type to the native paired vector type used
// for 16-byte vectorized scans: half -> half2, __nv_bfloat16 -> __nv_bfloat162.
// Any other type (e.g. float) maps to itself, i.e. scalar lanes.
template <typename T>
struct VecTypeImpl
{
    using type = T;
};

template <>
struct VecTypeImpl<half>
{
    using type = half2;
};

template <>
struct VecTypeImpl<__nv_bfloat16>
{
    using type = __nv_bfloat162;
};

// Convenience alias: VecType<T> is the vectorized lane type for T.
template <typename T>
using VecType = typename VecTypeImpl<T>::type;
325+
// Returns the block-wide maximum absolute value of the 16 bytes in `vec`,
// interpreted as packed elements of type T (read through VecType<T> lanes).
//
// NOTE: this function ends with a block-level reduction
// (tensorrt_llm::common::blockReduceMaxV2), so it MUST be invoked
// collectively by every thread of the block at the same call site; calling
// it from divergent control flow (e.g. loops with per-thread trip counts)
// is undefined behavior.
template <typename T>
__device__ float getMaxAbs(float4 const& vec)
{
    // Number of VecType<T> lanes packed into one float4 (16 bytes).
    // float -> 4 scalars; half/bf16 -> 4 two-element pairs.
    constexpr int kNumVecs = sizeof(float4) / sizeof(VecType<T>);
    auto const* lanes = reinterpret_cast<VecType<T> const*>(&vec);
    auto absMaxVec = cuda_abs(lanes[0]);
#pragma unroll
    for (int i = 1; i < kNumVecs; ++i)
    {
        absMaxVec = cuda_max(absMaxVec, cuda_abs(lanes[i]));
    }
    float absMaxVal;
    if constexpr (sizeof(T) == 4)
    {
        // 32-bit element: the lane is already a scalar.
        absMaxVal = static_cast<float>(absMaxVec);
    }
    else
    {
        // 16-bit elements come in pairs (half2 / __nv_bfloat162): fold x and y.
        absMaxVal = static_cast<float>(cuda_max(absMaxVec.x, absMaxVec.y));
    }
    // Reduce the per-thread maximum across the whole thread block.
    tensorrt_llm::common::blockReduceMaxV2<float, 1>(&absMaxVal);
    return absMaxVal;
}
346+
// Computes one FP4 global scale per token: 448 * 6 / max|x| over the token's
// row of `input`. Grid layout: blockIdx.x = batch index (gridDim.x == b),
// blockIdx.y strides over tokens; the threads of a block cooperate on a row.
// Preconditions (enforced by the launcher): n % (ElemsPerVec * 32) == 0 and
// `input` rows 16-byte aligned, since rows are read as float4 vectors.
template <typename T>
__global__ void computePerTokenGlobalScaleForFP4QuantizationKernel(
    int b, int m, int n, T const* input, int const* tokensPerBatch, float* globalScale)
{
    static constexpr int ElemsPerVec = 16 / sizeof(T);
    int batchIdx = blockIdx.x;
    // Rows beyond the real token count of this batch entry are padding.
    int realTokensNum = (tokensPerBatch == nullptr) ? m : tokensPerBatch[batchIdx];
    input += batchIdx * m * n;
    globalScale += batchIdx * m;
    int const numVecs = n / ElemsPerVec;
    for (int tokenIdx = blockIdx.y; tokenIdx < realTokensNum; tokenIdx += gridDim.y)
    {
        float perTokenMaxAbsVal = 0.f;
        // getMaxAbs performs a block-level reduction, so every thread must
        // execute the same number of loop iterations. Iterate a uniform,
        // rounded-up count and feed a zero vector to out-of-range lanes
        // (zero can never win an abs-max), instead of a per-thread
        // grid-stride loop that would diverge when blockDim.x does not
        // divide numVecs.
        int const numIters = (numVecs + blockDim.x - 1) / blockDim.x;
        for (int iter = 0; iter < numIters; ++iter)
        {
            int const vecIdx = iter * blockDim.x + threadIdx.x;
            float4 vec = make_float4(0.f, 0.f, 0.f, 0.f);
            if (vecIdx < numVecs)
            {
                vec = reinterpret_cast<float4 const*>(input + tokenIdx * n)[vecIdx];
            }
            float maxAbsVal = getMaxAbs<T>(vec);
            perTokenMaxAbsVal = cuda_max(perTokenMaxAbsVal, maxAbsVal);
        }
        // 448 (FP8 E4M3 max) * 6 (FP4 E2M1 max) / amax.
        // NOTE(review): an all-zero row yields an infinite scale here —
        // confirm upstream guarantees a nonzero amax per real token.
        float globalScaleVal = 448.f * 6.f / perTokenMaxAbsVal;
        if (threadIdx.x == 0)
        {
            globalScale[tokenIdx] = globalScaleVal;
        }
    }
}
372+
// Host-side launcher for computePerTokenGlobalScaleForFP4QuantizationKernel.
// Writes one float scale per token into `globalScale` (b * m entries;
// `tokensPerBatch`, when non-null, limits the rows actually written per
// batch). Requires n to be a multiple of ElemsPerVec * 32 and b > 0.
// Honors programmatic dependent launch when enabled via the environment.
template <typename T>
void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
    float* globalScale, int multiProcessorCount, cudaStream_t stream)
{
    static constexpr int ElemsPerVec = 16 / sizeof(T);
    // Rows must be consumable as whole float4 vectors by full warps.
    TLLM_CHECK(n % (ElemsPerVec * 32) == 0 and b > 0);

    int const vecsPerRow = n / ElemsPerVec;
    dim3 const blockDims(std::min(vecsPerRow, 1024));
    // One grid column per batch entry; spread the remaining SMs over tokens.
    dim3 const gridDims(b, std::max(1, std::min(m, multiProcessorCount / b)));

    cudaLaunchAttribute launchAttrs[1];
    launchAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    launchAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();

    cudaLaunchConfig_t launchConfig;
    launchConfig.gridDim = gridDims;
    launchConfig.blockDim = blockDims;
    launchConfig.dynamicSmemBytes = 0;
    launchConfig.stream = stream;
    launchConfig.attrs = launchAttrs;
    launchConfig.numAttrs = 1;

    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&launchConfig, computePerTokenGlobalScaleForFP4QuantizationKernel<T>, b, m, n,
        input, tokensPerBatch, globalScale));
}
396+
305397// Instantiate the function.
306398template void invokeFP4Quantization<half, 16 >(int b, int m, int n, half const * input, float const * SFScale,
307399 int64_t * output, int32_t * SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
@@ -311,6 +403,8 @@ template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* i
311403 cudaStream_t stream);
312404template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const * input, int64_t * output,
313405 int32_t * SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
406+ template void computePerTokenGlobalScaleForFP4Quantization<half>(int b, int m, int n, half const * input,
407+ int const * tokensPerBatch, float * globalScale, int multiProcessorCount, cudaStream_t stream);
314408#ifdef ENABLE_BF16
315409template void invokeFP4Quantization<__nv_bfloat16, 16 >(int b, int m, int n, __nv_bfloat16 const * input,
316410 float const * SFScale, int64_t * output, int32_t * SFOuput, bool useUE8M0, QuantizationSFLayout layout,
@@ -320,6 +414,9 @@ template void invokeFP4Quantization<__nv_bfloat16, 32>(int b, int m, int n, __nv
320414 int multiProcessorCount, cudaStream_t stream);
321415template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, __nv_bfloat16 const * input,
322416 int64_t * output, int32_t * SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
417+ template void computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(int b, int m, int n,
418+ __nv_bfloat16 const * input, int const * tokensPerBatch, float * globalScale, int multiProcessorCount,
419+ cudaStream_t stream);
323420#endif
324421
325422#ifdef ENABLE_FP8
0 commit comments