97 changes: 97 additions & 0 deletions cpp/tensorrt_llm/kernels/quantization.cu
@@ -302,6 +302,98 @@ void invokeBlockScaleInterleaveReverse(
block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
}

template <typename T>
struct VecTypeImpl
{
using type = T;
};

template <>
struct VecTypeImpl<half>
{
using type = half2;
};

template <>
struct VecTypeImpl<__nv_bfloat16>
{
using type = __nv_bfloat162;
};

template <typename T>
using VecType = typename VecTypeImpl<T>::type;

template <typename T>
__device__ float getMaxAbs(float4& vec)
{
auto absMaxVec = cuda_abs(reinterpret_cast<VecType<T>*>(&vec)[0]);
for (int i = 1; i < 4; ++i)
{
absMaxVec = cuda_max(absMaxVec, cuda_abs(reinterpret_cast<VecType<T>*>(&vec)[i]));
}
float absMaxVal;
if constexpr (sizeof(T) == 4)
{
absMaxVal = static_cast<float>(absMaxVec);
}
else
{
absMaxVal = static_cast<float>(cuda_max(absMaxVec.x, absMaxVec.y));
}
tensorrt_llm::common::blockReduceMaxV2<float, 1>(&absMaxVal);
return absMaxVal;
}

template <typename T>
__global__ void computePerTokenGlobalScaleForFP4QuantizationKernel(
int b, int m, int n, T const* input, int const* tokensPerBatch, float* globalScale)
{
static constexpr int ElemsPerVec = 16 / sizeof(T);
int batchIdx = blockIdx.x;
int realTokensNum = (tokensPerBatch == nullptr) ? m : tokensPerBatch[batchIdx];
input += batchIdx * m * n;
globalScale += batchIdx * m;
for (int tokenIdx = blockIdx.y; tokenIdx < realTokensNum; tokenIdx += gridDim.y)
{
float perTokenMaxAbsVal = 0.f;
for (int vecIdx = threadIdx.x; vecIdx < n / ElemsPerVec; vecIdx += blockDim.x)
{
float4 vec = reinterpret_cast<float4 const*>(input + tokenIdx * n)[vecIdx];
float maxAbsVal = getMaxAbs<T>(vec);
perTokenMaxAbsVal = cuda_max(perTokenMaxAbsVal, maxAbsVal);
}
float globalScaleVal = 448.f * 6.f / perTokenMaxAbsVal;
if (threadIdx.x == 0)
{
globalScale[tokenIdx] = globalScaleVal;
}
}
}

template <typename T>
void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
float* globalScale, int multiProcessorCount, cudaStream_t stream)
{

static constexpr int ElemsPerVec = 16 / sizeof(T);
TLLM_CHECK(n % (ElemsPerVec * 32) == 0 and b > 0);
dim3 block(std::min(n / ElemsPerVec, 1024));
dim3 grid(b, std::max(1, std::min(m, multiProcessorCount / b)));

cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&config, computePerTokenGlobalScaleForFP4QuantizationKernel<T>, b, m, n, input, tokensPerBatch, globalScale));
}

// Instantiate the function.
template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
@@ -311,6 +403,8 @@ template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* i
cudaStream_t stream);
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input, int64_t* output,
int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
template void computePerTokenGlobalScaleForFP4Quantization<half>(int b, int m, int n, half const* input,
int const* tokensPerBatch, float* globalScale, int multiProcessorCount, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeFP4Quantization<__nv_bfloat16, 16>(int b, int m, int n, __nv_bfloat16 const* input,
float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout,
@@ -320,6 +414,9 @@ template void invokeFP4Quantization<__nv_bfloat16, 32>(int b, int m, int n, __nv
int multiProcessorCount, cudaStream_t stream);
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, __nv_bfloat16 const* input,
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
template void computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(int b, int m, int n,
__nv_bfloat16 const* input, int const* tokensPerBatch, float* globalScale, int multiProcessorCount,
cudaStream_t stream);
#endif

#ifdef ENABLE_FP8
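For reference, the kernel added above computes, for every valid token row, the NVFP4 global scale `448 * 6 / max(|x|)` over the hidden dimension, with `tokensPerBatch` optionally limiting how many rows of each batch are processed. A minimal eager PyTorch sketch of the same math follows; the helper name `reference_per_token_global_scale` and the zeroing of unwritten rows are illustrative choices, not part of this PR.

```python
from typing import Optional

import torch


def reference_per_token_global_scale(
        x: torch.Tensor,
        tokens_per_batch: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Eager sketch of computePerTokenGlobalScaleForFP4Quantization:
    scale[b, t] = 448 * 6 / max(|x[b, t, :]|).

    x is [batch, tokens, hidden] or [tokens, hidden], fp16/bf16.
    Returns float32 with the last dimension replaced by 1.
    """
    x3 = x if x.dim() == 3 else x.unsqueeze(0)          # [b, m, n]
    amax = x3.abs().amax(dim=-1, keepdim=True).float()  # per-token abs max
    scale = (448.0 * 6.0) / amax                        # same constant as the kernel
    if tokens_per_batch is not None:
        # The kernel leaves rows at or beyond tokens_per_batch[b] unwritten;
        # zeroing them here is only to make the valid region explicit.
        for b, t in enumerate(tokens_per_batch.tolist()):
            scale[b, t:] = 0.0
    return scale if x.dim() == 3 else scale.squeeze(0)
```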
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/kernels/quantization.h
@@ -88,5 +88,9 @@ void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
void invokeBlockScaleInterleaveReverse(
int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);

template <typename T>
void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* input, int const* tokensPerBatch,
float* globalScale, int multiProcessorCount, cudaStream_t stream = 0);

} // namespace kernels
} // namespace tensorrt_llm
79 changes: 79 additions & 0 deletions cpp/tensorrt_llm/thop/fp4Quantize.cpp
@@ -153,6 +153,83 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::opt

return {valueE2M1, scaleFP8SF};
}

at::Tensor calculate_nvfp4_global_scale(at::Tensor const& input, std::optional<at::Tensor> const& tokensPerBatch)
{
CHECK_TH_CUDA(input);
CHECK_CONTIGUOUS(input);

auto const& inputShape = input.sizes();
auto const& rank = inputShape.size();

TORCH_CHECK(rank >= 2 && rank <= 3);

// Calculate batch and token numbers
int64_t batch_size = 1;
int64_t token_num = 1;
int64_t hidden_size = inputShape[rank - 1];

if (rank == 2)
{
// [token_num, hidden_size]
token_num = inputShape[0];
batch_size = 1;
}
else if (rank == 3)
{
// [batch, token_num, hidden_size]
batch_size = inputShape[0];
token_num = inputShape[1];
}

// Create output tensor with same dimensions as input, but last dimension size is 1
std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape[rank - 1] = 1;

at::Tensor globalScale = at::detail::empty_cuda(outputShape, torch::kFloat32, input.device(), std::nullopt);

// Get multi-processor count
static int multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

// Prepare tokensPerBatch pointer - should have shape (batch_size)
int const* tokensPerBatchPtr = nullptr;
if (tokensPerBatch.has_value())
{
CHECK_TH_CUDA(tokensPerBatch.value());
CHECK_CONTIGUOUS(tokensPerBatch.value());

auto const& tokensShape = tokensPerBatch.value().sizes();
TORCH_CHECK(tokensShape.size() == 1, "tokensPerBatch should have exactly 1 dimension");
TORCH_CHECK(tokensShape[0] == batch_size, "tokensPerBatch first dimension must match input batch size");

tokensPerBatchPtr = tokensPerBatch.value().data_ptr<int>();
}

// Call corresponding kernel based on input data type
if (input.scalar_type() == at::ScalarType::Half)
{
tensorrt_llm::kernels::computePerTokenGlobalScaleForFP4Quantization<half>(batch_size, token_num, hidden_size,
reinterpret_cast<half const*>(input.data_ptr()), tokensPerBatchPtr, globalScale.data_ptr<float>(),
multiProcessorCount, at::cuda::getCurrentCUDAStream(input.get_device()));
}
else if (input.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
tensorrt_llm::kernels::computePerTokenGlobalScaleForFP4Quantization<__nv_bfloat16>(batch_size, token_num,
hidden_size, reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()), tokensPerBatchPtr,
globalScale.data_ptr<float>(), multiProcessorCount, at::cuda::getCurrentCUDAStream(input.get_device()));
#else
C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to compute global scale for bf16 tensor.");
#endif
}
else
{
C10_THROW_ERROR(
NotImplementedError, "calculate_nvfp4_global_scale only supports input tensor with dtypes fp16/bf16.");
}

return globalScale;
}
} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
@@ -161,9 +238,11 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
"fp4_quantize(Tensor input, Tensor? globalScale, int sfVecSize, bool sfUseUE8M0=False, bool "
"isSfSwizzledLayout=True) "
"-> (Tensor, Tensor)");
m.def("calculate_nvfp4_global_scale(Tensor input, Tensor? tokensPerBatch) -> Tensor");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fp4_quantize", TORCH_FN(torch_ext::fp4_quantize));
m.impl("calculate_nvfp4_global_scale", TORCH_FN(torch_ext::calculate_nvfp4_global_scale));
}
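Once the extension above is loaded, the new op is callable from Python. A hedged usage sketch, assuming a CUDA device and illustrative tensor sizes: the op expects an fp16/bf16 tensor of rank 2 or 3 (for these dtypes the kernel checks that the hidden size is a multiple of 256) and returns a float32 tensor with the last dimension replaced by 1.

```python
import torch

# 2-D input: [token_num, hidden_size]
x = torch.randn(128, 4096, dtype=torch.float16, device="cuda")
scale = torch.ops.trtllm.calculate_nvfp4_global_scale(x, None)
assert scale.shape == (128, 1) and scale.dtype == torch.float32

# 3-D input: [batch, token_num, hidden_size] with per-batch valid token counts
xb = torch.randn(4, 256, 4096, dtype=torch.bfloat16, device="cuda")
tokens = torch.tensor([256, 100, 17, 1], dtype=torch.int32, device="cuda")
scale_b = torch.ops.trtllm.calculate_nvfp4_global_scale(xb, tokens)
# Rows at or beyond tokens[b] are not written by the kernel.
assert scale_b.shape == (4, 256, 1)
```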
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/thop/fp4Quantize.h
@@ -26,4 +26,6 @@ namespace torch_ext
{
std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::optional<at::Tensor> const& globalScale,
int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout);
}

at::Tensor calculate_nvfp4_global_scale(at::Tensor const& input, std::optional<at::Tensor> const& tokensPerBatch);
} // namespace torch_ext
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -179,6 +179,10 @@ def _(
return (input.new_empty(output_shape, dtype=torch.uint8),
global_scale.new_empty(scale_shape, dtype=torch.uint8))

@torch.library.register_fake("trtllm::calculate_nvfp4_global_scale")
def _(input: torch.Tensor, tokens_per_batch: Optional[torch.Tensor]):
return input.new_empty((*input.shape[:-1], 1), dtype=torch.float32)

@torch.library.register_fake("trtllm::moe_comm")
def _(
inputs: List[torch.Tensor],
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -691,8 +691,8 @@ def forward_chunk(
self.expert_size_per_partition,
num_tokens_per_expert_for_fused_moe, self.hidden_size)
if self.use_low_precision_combine:
global_scales = (448 * 6) / final_hidden_states.abs().max(
dim=-1, keepdim=True).values.to(torch.float32)
global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale(
final_hidden_states, recv_expert_count)
final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4(
final_hidden_states, global_scales, deep_ep_topk_idx,
deep_ep_topk_weights, deep_ep_handle)
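The diff above swaps the eager per-token formula for the fused op, with `recv_expert_count` playing the role of `tokensPerBatch`. A hedged sanity-check sketch for rows that hold valid tokens (variable names and tolerances are illustrative):

```python
import torch

hidden = torch.randn(32, 2048, dtype=torch.bfloat16, device="cuda")
# Eager formula that the fused op replaces.
eager = (448 * 6) / hidden.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
fused = torch.ops.trtllm.calculate_nvfp4_global_scale(hidden, None)
torch.testing.assert_close(fused, eager, rtol=1e-3, atol=1e-3)
```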