Commit 0a0159f

[https://nvbugs/5378031] [feat] W4A8 AWQ MoE supports Per Expert Pre-quant Scale Factor for PyT backend (#7286)
Signed-off-by: Min Yu <[email protected]>
1 parent e75b4f9 commit 0a0159f
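
In short, this commit generalizes the AWQ pre-quant smoothing step for W4A8 MoE: instead of one scale vector shared by all experts, each expert gets its own row in a [num_experts, hidden] scale table, and each permuted token row is scaled by the table row of the expert it was routed to. A minimal CPU reference of that math (a hypothetical helper for orientation, not code from the commit; offsets[e] marks the first permuted row owned by expert e):

#include <cstdint>
#include <vector>

// Hypothetical CPU reference: act is [rows, hidden] with rows grouped by
// expert, scales is [num_experts, hidden] row-major, offsets has
// num_experts + 1 entries (exclusive prefix sum of per-expert row counts).
std::vector<float> perExpertPrequantScaleRef(std::vector<float> const& act,
    std::vector<float> const& scales, std::vector<int64_t> const& offsets,
    int64_t num_experts, int64_t rows, int64_t hidden)
{
    std::vector<float> out(act.size());
    int64_t expert = 0;
    for (int64_t r = 0; r < rows; ++r)
    {
        // Rows are sorted by expert, so a linear advance suffices here; the
        // kernels below use a binary search because threads visit rows independently.
        while (expert + 1 < num_experts && r >= offsets[expert + 1])
            ++expert;
        for (int64_t h = 0; h < hidden; ++h)
            out[r * hidden + h] = act[r * hidden + h] * scales[expert * hidden + h];
    }
    return out;
}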

7 files changed: +217 -72 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 1 addition & 1 deletion

@@ -859,7 +859,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface

     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
-        cudaStream_t stream);
+        cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);

     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
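
Because the two new trailing parameters default to nullptr and 0, existing call sites compile unchanged and keep the shared-scale behavior; only callers that pass a non-null offset array (the FC2 path below) opt into per-expert scales. A standalone toy illustrating the defaulting pattern (names hypothetical, not the real signature):

#include <cstdint>
#include <cstdio>

// Sketch of the defaulting pattern used above: omit the trailing arguments
// and you get the legacy shared-scale path.
void applyPrequantScaleSketch(int64_t* expert_first_token_offset = nullptr, int num_experts_per_node = 0)
{
    if (expert_first_token_offset != nullptr)
        printf("per-expert path, %d experts\n", num_experts_per_node);
    else
        printf("legacy shared-scale path\n");
}

int main()
{
    applyPrequantScaleSketch();           // old call sites: unchanged behavior
    int64_t offsets[] = {0, 3, 3, 8};
    applyPrequantScaleSketch(offsets, 3); // new FC2 call: opts into per-expert scales
    return 0;
}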

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 24 additions & 28 deletions
@@ -52,6 +52,7 @@
 #include "tensorrt_llm/common/dataType.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
+#include "tensorrt_llm/kernels/moe_utils.cuh"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"
 #include "tensorrt_llm/kernels/quantization.cuh"

@@ -897,27 +898,6 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
 }

 // ============================== Infer GEMM sizes =================================
-// TODO Could linear search be better for small # experts
-template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
-{
-    int64_t low = 0, high = arr_length - 1, target_location = -1;
-    while (low <= high)
-    {
-        int64_t mid = (low + high) / 2;
-
-        if (sorted_indices[mid] >= target)
-        {
-            high = mid - 1;
-        }
-        else
-        {
-            low = mid + 1;
-            target_location = mid;
-        }
-    }
-    return target_location + 1;
-}

 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
@@ -1508,14 +1488,18 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         static_assert(!is_nvfp4 && !is_mxfp8, "NVFP4 and MXFP8 are not supported for AWQ");
         static_assert(!std::is_same_v<InputActivationsType, ExpandedActivationsType>,
             "Input and output types must be different for AWQ");
+        int64_t expert = findTotalEltsLessThanTarget(
+                             expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1)
+            - 1;
         for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
         {
             auto frag_elems = source_row_ptr[elem_index];

             CUTLASS_PRAGMA_UNROLL
             for (int e = 0; e < ELEM_PER_THREAD; e++)
             {
-                frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e];
+                frag_elems[e]
+                    = frag_elems[e] * prequant_scales[expert * hidden_size + elem_index * ELEM_PER_THREAD + e];
             }

             dest_row_ptr[elem_index] = arrayConvert<DataElem, OutputElem>(frag_elems);
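
Two details worth noting in this hunk: the row's expert is resolved once, by binary search over expert_first_token_offset, and reused for every element the thread processes; and prequant_scales is now indexed as a row-major [num_experts, hidden_size] table, with expert * hidden_size selecting the row. A standalone toy example of that layout (all sizes made up):

#include <cstdint>
#include <cstdio>

int main()
{
    // Toy layout: 3 experts, hidden_size 4 -> scale table flattened to 12 entries.
    int64_t const num_experts = 3, hidden_size = 4;
    float scales[12];
    for (int64_t e = 0; e < num_experts; ++e)
        for (int64_t h = 0; h < hidden_size; ++h)
            scales[e * hidden_size + h] = 1.0f + e; // expert e scales by (e + 1)
    // A row owned by expert 2 reads scales[8..11]:
    int64_t const expert = 2;
    for (int64_t h = 0; h < hidden_size; ++h)
        printf("%.1f ", scales[expert * hidden_size + h]);
    printf("\n"); // prints: 3.0 3.0 3.0 3.0
    return 0;
}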
@@ -2918,7 +2902,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
 template <class T, class WeightType, class OutputType, class InputType, class ScaleBiasType, class Enable>
 T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
     void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
-    int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream)
+    int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
+    int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     T const* gemm_input;
     bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
@@ -2928,10 +2913,20 @@ T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType,
             (!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
         if constexpr (!std::is_same_v<T, WeightType>)
        {
-            tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
-                reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
-                reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
-                num_valid_tokens_ptr, stream);
+            if (expert_first_token_offset != nullptr)
+            {
+                tensorrt_llm::kernels::apply_per_channel_scale_per_expert_kernel_launcher<UnfusedGemmOutputType, T>(
+                    reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
+                    reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
+                    expert_first_token_offset, num_experts_per_node, num_valid_tokens_ptr, stream);
+            }
+            else
+            {
+                tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
+                    reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
+                    reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
+                    num_valid_tokens_ptr, stream);
+            }
         }
         gemm_input = reinterpret_cast<T const*>(smoothed_act);
     }
@@ -3740,7 +3735,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     }

     auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
-        num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
+        num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
+        num_experts_per_node);
     sync_check_cuda_error(stream);
     Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
         expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
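
The FC2 call above can pass expert_first_token_offset_ directly because the activations entering gemm2 are still permuted into expert order; the offsets array is the exclusive prefix sum of per-expert token counts. A small sketch of that relationship (function name and the counts are made up for illustration):

#include <cstdint>
#include <vector>

// Hypothetical illustration: expert_first_token_offset is the exclusive
// prefix sum of the number of rows routed to each expert on this node.
std::vector<int64_t> buildExpertFirstTokenOffset(std::vector<int64_t> const& tokens_per_expert)
{
    std::vector<int64_t> offset(tokens_per_expert.size() + 1, 0);
    for (size_t e = 0; e < tokens_per_expert.size(); ++e)
        offset[e + 1] = offset[e] + tokens_per_expert[e];
    return offset; // e.g. counts {3, 0, 5} -> offsets {0, 3, 3, 8}
}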
cpp/tensorrt_llm/kernels/moe_utils.cuh

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+
+// TODO Could linear search be better for small # experts
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
+{
+    int64_t low = 0, high = arr_length - 1, target_location = -1;
+    while (low <= high)
+    {
+        int64_t mid = (low + high) / 2;
+
+        if (sorted_indices[mid] >= target)
+        {
+            high = mid - 1;
+        }
+        else
+        {
+            low = mid + 1;
+            target_location = mid;
+        }
+    }
+    return target_location + 1;
+}
+
+} // namespace kernels
+} // namespace tensorrt_llm
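
For a sorted array, findTotalEltsLessThanTarget returns the number of entries strictly less than target, which is exactly the index std::lower_bound would return; the kernels call it with target = row + 1 to map a permuted row to its expert. A host-side check of that equivalence, reusing the {0, 3, 3, 8} offsets from the sketch above:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Host analogue of findTotalEltsLessThanTarget, for illustration only:
// count of entries strictly less than target == lower_bound index.
int64_t findTotalEltsLessThanTargetHost(int64_t const* sorted, int64_t n, int64_t target)
{
    return std::lower_bound(sorted, sorted + n, target) - sorted;
}

int main()
{
    // Offsets from the example above: experts own rows [0,3), [3,3), [3,8).
    int64_t const offsets[] = {0, 3, 3, 8};
    int64_t const num_experts = 3; // the kernels search only the first num_experts entries
    // Row 5 -> target 6 -> three offsets {0, 3, 3} are < 6 -> expert = 3 - 1 = 2.
    int64_t expert = findTotalEltsLessThanTargetHost(offsets, num_experts, /*row*/ 5 + 1) - 1;
    assert(expert == 2); // expert 2 owns rows [3, 8), which contains row 5
    return 0;
}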

cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu

Lines changed: 60 additions & 5 deletions
@@ -14,6 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "tensorrt_llm/kernels/moe_utils.cuh"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"

 namespace tensorrt_llm
@@ -41,7 +42,7 @@ struct Vec2Type<__nv_bfloat16>

 template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
 __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
-    int cols, int64_t const* num_valid_tokens_ptr)
+    int cols, int64_t const* num_valid_tokens_ptr, int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     T_in scale[kElems], act_vec[kElems];
@@ -53,11 +54,19 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
         return;
     act += row_offset * kProcessRows * cols;
     smoothed_act += row_offset * kProcessRows * cols;
-    *reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
 #pragma unroll
     for (int i = 0; i < kProcessRows; ++i)
     {
         *reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
+        int expert = 0;
+        if (expert_first_token_offset != nullptr)
+        {
+            expert = findTotalEltsLessThanTarget(
+                         expert_first_token_offset, num_experts_per_node, (int64_t) row_offset * kProcessRows + i + 1)
+                - 1;
+        }
+        *reinterpret_cast<AccessType*>(scale)
+            = reinterpret_cast<AccessType const*>(per_channel_scale)[expert * cols / kElems + col_offset];
         if constexpr ((std::is_same_v<T_in, half>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
             || std::is_same_v<T_in, __nv_bfloat16>
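
One subtlety in the new scale load: per_channel_scale is read in AccessType vectors of kElems scalars, so an expert's scale row spans cols / kElems vectors and the vector index becomes expert * cols / kElems + col_offset (cols must be divisible by kElems, which the launcher's grid math already assumes). A quick standalone check of that arithmetic with toy sizes:

#include <cassert>
#include <cstdint>

int main()
{
    // e.g. a float4 of half packs kElems = 8 scalars per vector load.
    int const kElems = 8, cols = 64, expert = 3, col_offset = 2;
    // Vector index used by the kernel above ...
    int vec_idx = expert * cols / kElems + col_offset;
    // ... equals the scalar index of the vector's first element, divided by kElems.
    int scalar_idx = expert * cols + col_offset * kElems;
    assert(vec_idx == scalar_idx / kElems); // 26 == 208 / 8
    return 0;
}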
@@ -98,13 +107,14 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_

 template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
 void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
-    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
+    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0,
+    int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     dim3 block(128);
     dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
-    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
-        <<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
+    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(smoothed_act, act,
+        per_channel_scale, rows, cols, num_valid_tokens_ptr, expert_first_token_offset, num_experts_per_node);
 }

 template <typename T_in, typename T_out>
@@ -134,6 +144,34 @@ void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* ac
     }
 }

+template <typename T_in, typename T_out>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+{
+    uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
+    if (elems < 2048 * 2048)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 4096 * 4096)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 8192 * 8192)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+}
+
 #define INSTANTIATE_PREQUANT_SCALE(T_in, T_out)                                                                       \
     template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act,         \
         const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
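
The per-expert launcher reuses the existing size heuristic: larger problems fold more rows into each thread block (kProcessRows of 1, 4, 8, or 16) so the grid stays bounded, with thresholds at 2048*2048, 4096*4096, and 8192*8192 total elements. A standalone sketch of just the selection logic:

#include <cstdint>

// Sketch of the kProcessRows selection shared by both launchers above.
// With kProcessRows = 4, grid.x = ceil(rows / 4), i.e. 4x fewer grid rows.
int selectProcessRows(uint64_t rows, uint64_t cols)
{
    uint64_t elems = rows * cols;
    if (elems < 2048ull * 2048ull)
        return 1;
    if (elems < 4096ull * 4096ull)
        return 4;
    if (elems < 8192ull * 8192ull)
        return 8;
    return 16;
}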
@@ -150,5 +188,22 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
 #endif
 #endif

+#define INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(T_in, T_out)                                                            \
+    template void apply_per_channel_scale_per_expert_kernel_launcher<T_in, T_out>(T_out * smoothed_act,               \
+        const T_in* act, const T_in* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,       \
+        int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, half);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, __nv_fp8_e4m3);
+#endif
+
+#if defined(ENABLE_BF16)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_bfloat16);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_fp8_e4m3);
+#endif
+#endif
+
 } // namespace kernels
 } // namespace tensorrt_llm

cpp/tensorrt_llm/kernels/preQuantScaleKernel.h

Lines changed: 5 additions & 0 deletions
@@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
 void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
     int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0);

+template <typename T_in, typename T_out = T_in>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream);
+
 } // namespace kernels
 } // namespace tensorrt_llm
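
For completeness, a hedged usage sketch of the new entry point; the wrapper name, buffer shapes, and the half to FP8 type pair are assumptions for illustration (that instantiation does exist above under ENABLE_FP8), not code from the commit:

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "tensorrt_llm/kernels/preQuantScaleKernel.h"

// Assumed device buffers: act [rows, cols] (half), scales [num_experts, cols]
// (half), expert_first_token_offset [num_experts + 1] (int64_t),
// out [rows, cols] (FP8). All pointers are device memory.
void smoothActivationsPerExpert(__nv_fp8_e4m3* out, half const* act, half const* scales,
    int64_t* expert_first_token_offset, int num_experts, int rows, int cols, cudaStream_t stream)
{
    tensorrt_llm::kernels::apply_per_channel_scale_per_expert_kernel_launcher<half, __nv_fp8_e4m3>(
        out, act, scales, rows, cols, expert_first_token_offset, num_experts,
        /*num_valid_tokens_ptr=*/nullptr, stream);
}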
