Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface

T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
cudaStream_t stream);
cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);

MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
Expand Down
52 changes: 24 additions & 28 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/moe_utils.cuh"
#include "tensorrt_llm/kernels/preQuantScaleKernel.h"
#include "tensorrt_llm/kernels/quantization.cuh"

Expand Down Expand Up @@ -897,27 +898,6 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
}

// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
{
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high)
{
int64_t mid = (low + high) / 2;

if (sorted_indices[mid] >= target)
{
high = mid - 1;
}
else
{
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}

template <class T>
using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
Expand Down Expand Up @@ -1508,14 +1488,18 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
static_assert(!is_nvfp4 && !is_mxfp8, "NVFP4 and MXFP8 are not supported for AWQ");
static_assert(!std::is_same_v<InputActivationsType, ExpandedActivationsType>,
"Input and output types must be different for AWQ");
int64_t expert = findTotalEltsLessThanTarget(
expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1)
- 1;
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
auto frag_elems = source_row_ptr[elem_index];

CUTLASS_PRAGMA_UNROLL
for (int e = 0; e < ELEM_PER_THREAD; e++)
{
frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e];
frag_elems[e]
= frag_elems[e] * prequant_scales[expert * hidden_size + elem_index * ELEM_PER_THREAD + e];
}

dest_row_ptr[elem_index] = arrayConvert<DataElem, OutputElem>(frag_elems);
Expand Down Expand Up @@ -2918,7 +2902,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
template <class T, class WeightType, class OutputType, class InputType, class ScaleBiasType, class Enable>
T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream)
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
int64_t* expert_first_token_offset, int const num_experts_per_node)
{
T const* gemm_input;
bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
Expand All @@ -2928,10 +2913,20 @@ T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType,
(!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
if constexpr (!std::is_same_v<T, WeightType>)
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
num_valid_tokens_ptr, stream);
if (expert_first_token_offset != nullptr)
{
tensorrt_llm::kernels::apply_per_channel_scale_per_expert_kernel_launcher<UnfusedGemmOutputType, T>(
reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
expert_first_token_offset, num_experts_per_node, num_valid_tokens_ptr, stream);
}
else
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
num_valid_tokens_ptr, stream);
}
}
gemm_input = reinterpret_cast<T const*>(smoothed_act);
}
Expand Down Expand Up @@ -3740,7 +3735,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
}

auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
num_experts_per_node);
sync_check_cuda_error(stream);
Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
Expand Down
48 changes: 48 additions & 0 deletions cpp/tensorrt_llm/kernels/moe_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace tensorrt_llm
{
namespace kernels
{

// TODO Could linear search be better for small # experts
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
{
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high)
{
int64_t mid = (low + high) / 2;

if (sorted_indices[mid] >= target)
{
high = mid - 1;
}
else
{
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}

} // namespace kernels
} // namespace tensorrt_llm
65 changes: 60 additions & 5 deletions cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/moe_utils.cuh"
#include "tensorrt_llm/kernels/preQuantScaleKernel.h"

namespace tensorrt_llm
Expand Down Expand Up @@ -41,7 +42,7 @@ struct Vec2Type<__nv_bfloat16>

template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
int cols, int64_t const* num_valid_tokens_ptr)
int cols, int64_t const* num_valid_tokens_ptr, int64_t* expert_first_token_offset, int const num_experts_per_node)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
T_in scale[kElems], act_vec[kElems];
Expand All @@ -53,11 +54,19 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
return;
act += row_offset * kProcessRows * cols;
smoothed_act += row_offset * kProcessRows * cols;
*reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
int expert = 0;
if (expert_first_token_offset != nullptr)
{
expert = findTotalEltsLessThanTarget(
expert_first_token_offset, num_experts_per_node, (int64_t) row_offset * kProcessRows + i + 1)
- 1;
}
*reinterpret_cast<AccessType*>(scale)
= reinterpret_cast<AccessType const*>(per_channel_scale)[expert * cols / kElems + col_offset];
if constexpr ((std::is_same_v<T_in, half>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|| std::is_same_v<T_in, __nv_bfloat16>
Expand Down Expand Up @@ -98,13 +107,14 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_

template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0,
int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
dim3 block(128);
dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
<<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(smoothed_act, act,
per_channel_scale, rows, cols, num_valid_tokens_ptr, expert_first_token_offset, num_experts_per_node);
}

template <typename T_in, typename T_out>
Expand Down Expand Up @@ -134,6 +144,34 @@ void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* ac
}
}

template <typename T_in, typename T_out>
void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
{
uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
if (elems < 2048 * 2048)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
else if (elems < 4096 * 4096)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
else if (elems < 8192 * 8192)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
else
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
}

#define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act, \
const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
Expand All @@ -150,5 +188,22 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif

#define INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(T_in, T_out) \
template void apply_per_channel_scale_per_expert_kernel_launcher<T_in, T_out>(T_out * smoothed_act, \
const T_in* act, const T_in* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset, \
int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)

INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, half);
#if defined(ENABLE_FP8)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, __nv_fp8_e4m3);
#endif

#if defined(ENABLE_BF16)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_bfloat16);
#if defined(ENABLE_FP8)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif

} // namespace kernels
} // namespace tensorrt_llm
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/kernels/preQuantScaleKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0);

template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
Loading