Commit 2bf15fd

Apply per-expert act scale to FC2 for w4a8 MoE on PyT flow

Signed-off-by: Min Yu <[email protected]>
1 parent c0d8b42

File tree: 6 files changed (+160, −45 lines)

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 1 addition & 1 deletion
@@ -858,7 +858,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface

     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
-        cudaStream_t stream);
+        cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);

     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 19 additions & 27 deletions
@@ -52,6 +52,7 @@
 #include "tensorrt_llm/common/dataType.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
+#include "tensorrt_llm/kernels/moe_utils.h"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"
 #include "tensorrt_llm/kernels/quantization.cuh"

@@ -897,27 +898,6 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
 }

 // ============================== Infer GEMM sizes =================================
-// TODO Could linear search be better for small # experts
-template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
-{
-    int64_t low = 0, high = arr_length - 1, target_location = -1;
-    while (low <= high)
-    {
-        int64_t mid = (low + high) / 2;
-
-        if (sorted_indices[mid] >= target)
-        {
-            high = mid - 1;
-        }
-        else
-        {
-            low = mid + 1;
-            target_location = mid;
-        }
-    }
-    return target_location + 1;
-}

 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;

@@ -2922,7 +2902,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
 template <class T, class WeightType, class OutputType, class InputType, class ScaleBiasType, class Enable>
 T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
     void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
-    int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream)
+    int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
+    int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     T const* gemm_input;
     bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;

@@ -2932,10 +2913,20 @@ T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType,
         (!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
     if constexpr (!std::is_same_v<T, WeightType>)
     {
-        tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
-            reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
-            reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
-            num_valid_tokens_ptr, stream);
+        if (expert_first_token_offset != nullptr)
+        {
+            tensorrt_llm::kernels::apply_per_channel_scale_per_expert_kernel_launcher<UnfusedGemmOutputType, T>(
+                reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
+                reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
+                expert_first_token_offset, num_experts_per_node, num_valid_tokens_ptr, stream);
+        }
+        else
+        {
+            tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
+                reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
+                reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
+                num_valid_tokens_ptr, stream);
+        }
     }
     gemm_input = reinterpret_cast<T const*>(smoothed_act);
 }

@@ -3744,7 +3735,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     }

     auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
-        num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
+        num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
+        num_experts_per_node);
     sync_check_cuda_error(stream);
     Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
         expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
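
Net effect of this file: when the runner passes `expert_first_token_offset_` (the prefix-sum array marking where each expert's block of permuted rows begins) together with `num_experts_per_node`, FC2's AWQ pre-quant scaling switches to the new per-expert launcher; callers that omit the two defaulted arguments keep the old single-scale-row behavior. A minimal Python sketch of that opt-in dispatch (hypothetical function, not the library API; assumes `scales` holds one row per expert):

import numpy as np

def apply_prequant_scale(act, scales, expert_first_token_offset=None):
    # Hedged sketch: when no offset array is given, every row shares one
    # scale vector (the legacy per-channel path).
    if expert_first_token_offset is None:
        return act * scales[0]
    # Expert e owns the contiguous block of permuted rows starting at
    # expert_first_token_offset[e]; expand to one scale row per token.
    rows_per_expert = np.diff(np.append(expert_first_token_offset, len(act)))
    return act * np.repeat(scales, rows_per_expert, axis=0)

act = np.ones((5, 4), dtype=np.float32)
scales = np.array([[0.5] * 4, [2.0] * 4, [3.0] * 4], dtype=np.float32)
print(apply_prequant_scale(act, scales)[:, 0])                       # all 0.5
print(apply_prequant_scale(act, scales, np.array([0, 3, 3]))[:, 0])
# -> [0.5 0.5 0.5 3.  3. ]  (expert 1 owns no rows)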
cpp/tensorrt_llm/kernels/moe_utils.h (new file)

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+
+// TODO Could linear search be better for small # experts
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
+{
+    int64_t low = 0, high = arr_length - 1, target_location = -1;
+    while (low <= high)
+    {
+        int64_t mid = (low + high) / 2;
+
+        if (sorted_indices[mid] >= target)
+        {
+            high = mid - 1;
+        }
+        else
+        {
+            low = mid + 1;
+            target_location = mid;
+        }
+    }
+    return target_location + 1;
+}
+
+} // namespace kernels
+} // namespace tensorrt_llm
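
This helper is moved verbatim from moe_kernels.cu so that preQuantScaleKernel.cu can reuse it. It returns the number of entries in a sorted array that are strictly less than `target`; calling it with `target = row + 1` over the first `num_experts_per_node` entries of `expert_first_token_offset` and subtracting 1 yields the expert that owns permuted row `row`, including when an expert owns zero rows. A Python mirror with a worked example (illustrative only):

def find_total_elts_less_than_target(sorted_indices, target):
    """Python mirror of the device helper: count of entries in the
    sorted array that are strictly less than target."""
    low, high, target_location = 0, len(sorted_indices) - 1, -1
    while low <= high:
        mid = (low + high) // 2
        if sorted_indices[mid] >= target:
            high = mid - 1
        else:
            low = mid + 1
            target_location = mid
    return target_location + 1

# First num_experts_per_node entries of expert_first_token_offset for
# three experts owning 3, 0, and 2 permuted rows respectively:
offsets = [0, 3, 3]
for row in range(5):
    expert = find_total_elts_less_than_target(offsets, row + 1) - 1
    print(row, expert)  # rows 0-2 -> expert 0, rows 3-4 -> expert 2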

cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu

Lines changed: 60 additions & 5 deletions
@@ -14,6 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "tensorrt_llm/kernels/moe_utils.h"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"

 namespace tensorrt_llm

@@ -41,7 +42,7 @@ struct Vec2Type<__nv_bfloat16>

 template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
 __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
-    int cols, int64_t const* num_valid_tokens_ptr)
+    int cols, int64_t const* num_valid_tokens_ptr, int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     T_in scale[kElems], act_vec[kElems];

@@ -53,11 +54,19 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
         return;
     act += row_offset * kProcessRows * cols;
     smoothed_act += row_offset * kProcessRows * cols;
-    *reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
 #pragma unroll
     for (int i = 0; i < kProcessRows; ++i)
     {
         *reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
+        int expert = 0;
+        if (expert_first_token_offset != nullptr)
+        {
+            expert = findTotalEltsLessThanTarget(
+                         expert_first_token_offset, num_experts_per_node, (int64_t) row_offset * kProcessRows + i + 1)
+                - 1;
+        }
+        *reinterpret_cast<AccessType*>(scale)
+            = reinterpret_cast<AccessType const*>(per_channel_scale)[expert * cols / kElems + col_offset];
         if constexpr ((std::is_same_v<T_in, half>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
             || std::is_same_v<T_in, __nv_bfloat16>

@@ -98,13 +107,14 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_

 template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
 void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
-    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
+    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0,
+    int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     dim3 block(128);
     dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
-    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
-        <<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
+    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(smoothed_act, act,
+        per_channel_scale, rows, cols, num_valid_tokens_ptr, expert_first_token_offset, num_experts_per_node);
 }

 template <typename T_in, typename T_out>

@@ -134,6 +144,34 @@ void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* ac
     }
 }

+template <typename T_in, typename T_out>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+{
+    uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
+    if (elems < 2048 * 2048)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 4096 * 4096)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 8192 * 8192)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+}
+
 #define INSTANTIATE_PREQUANT_SCALE(T_in, T_out)                                                                       \
     template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act,         \
         const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)

@@ -150,5 +188,22 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
 #endif
 #endif

+#define INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(T_in, T_out)                                                            \
+    template void apply_per_channel_scale_per_expert_kernel_launcher<T_in, T_out>(T_out * smoothed_act,               \
+        const T_in* act, const T_in* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,       \
+        int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, half);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, __nv_fp8_e4m3);
+#endif
+
+#if defined(ENABLE_BF16)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_bfloat16);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_fp8_e4m3);
+#endif
+#endif
+
 } // namespace kernels
 } // namespace tensorrt_llm
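
In scalar terms the kernel now computes `smoothed_act[r][c] = act[r][c] * per_channel_scale[expert(r)][c]`, where `expert(r)` comes from the binary search above; the index `expert * cols / kElems + col_offset` is `per_channel_scale[expert][col]` measured in `AccessType`-sized vectors, and the existing `kProcessRows` tiling heuristic (1/4/8/16 by problem size) is reused unchanged. A NumPy reference of the intended result (our own sketch of the math, not the kernel's vectorized layout):

import numpy as np

def per_expert_scale_reference(act, per_channel_scale, expert_first_token_offset):
    """act: (rows, cols); per_channel_scale: (num_experts, cols);
    expert_first_token_offset[e] = first permuted row owned by expert e."""
    rows = act.shape[0]
    # Row r belongs to the last expert whose first-token offset is <= r,
    # matching findTotalEltsLessThanTarget(offsets, E, r + 1) - 1.
    expert = np.searchsorted(expert_first_token_offset,
                             np.arange(rows), side="right") - 1
    return act * per_channel_scale[expert]

act = np.ones((5, 4), dtype=np.float32)
scale = np.array([[0.5] * 4, [2.0] * 4, [3.0] * 4], dtype=np.float32)
# experts own rows [0, 3), [3, 3) (empty), and [3, 5):
print(per_expert_scale_reference(act, scale, np.array([0, 3, 3]))[:, 0])
# -> [0.5 0.5 0.5 3.  3. ]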

cpp/tensorrt_llm/kernels/preQuantScaleKernel.h

Lines changed: 5 additions & 0 deletions
@@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
 void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
     int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0);

+template <typename T_in, typename T_out = T_in>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream);
+
 } // namespace kernels
 } // namespace tensorrt_llm

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 27 additions & 12 deletions
@@ -918,7 +918,10 @@ def create_weights(self, module: torch.nn.Module):
         module.register_parameter("fc31_act_scale", fc31_act_scale)

         fc2_act_scale = nn.Parameter(torch.empty(
-            1, module.intermediate_size_per_partition, 1, dtype=module.dtype),
+            module.expert_size_per_partition,
+            module.intermediate_size_per_partition,
+            1,
+            dtype=module.dtype),
                                      requires_grad=False)
         module.register_parameter("fc2_act_scale", fc2_act_scale)

@@ -1246,23 +1249,31 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w2_pre_quant_scales_max = torch.max(
-            torch.stack(all_w2_pre_quant_scales).to(module.dtype),
-            dim=0).values
+        all_w2_pre_quant_scales = torch.stack(all_w2_pre_quant_scales).to(
+            module.dtype)
+        all_w2_input_scales = torch.stack(all_w2_input_scales).to(
+            module.dtype)
+        all_w2_pre_quant_scales_div_input_scales = (
+            all_w2_pre_quant_scales.permute(1, 0) *
+            (1 / (all_w2_input_scales.reshape(
+                module.expert_size_per_partition).float()))).permute(1, 0)
         module.fc2_act_scale.data.copy_(
-            torch.ones_like(module.fc2_act_scale, device=self.device) *
-            (all_w2_pre_quant_scales_max.unsqueeze(-1)) *
-            (1 / all_w2_input_scales_max))
+            all_w2_pre_quant_scales_div_input_scales.reshape(
+                module.fc2_act_scale.shape))
         # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
         all_w2_weight_scale_2 = [
             load_weight_shard(weights[f"{expert_id}.w2.weight_scale_2"],
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w2_weight_scale_2_max = torch.stack(all_w2_weight_scale_2).to(
-            module.dtype).max()
-        module.fc2_alpha.data.copy_(all_w2_weight_scale_2_max.float() *
-                                    all_w2_input_scales_max.float())
+        all_w2_weight_scale_2 = torch.stack(all_w2_weight_scale_2).to(
+            module.dtype)
+        all_w2_weight_scale_2_mul_input_scales = (
+            all_w2_weight_scale_2.reshape(module.expert_size_per_partition,
+                                          1) *
+            all_w2_input_scales.reshape(module.expert_size_per_partition,
+                                        1))
+        module.fc2_alpha.data.copy_(all_w2_weight_scale_2_mul_input_scales)

         # Per-group weight_scale
         all_w2_scales = [

@@ -1281,7 +1292,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
             module.dtype)

         if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-            w2_scales /= all_w2_weight_scale_2_max.float()
+            w2_scales = w2_scales.permute(1, 2, 0)
+            all_w2_weight_scale_2 = all_w2_weight_scale_2.reshape(
+                module.expert_size_per_partition)
+            w2_scales /= (all_w2_weight_scale_2.float())
+            w2_scales = w2_scales.permute(2, 0, 1)
         w2_s_shape = w2_scales.shape
         w2_scales_interleaved = w2_scales.reshape(
             w2_s_shape[0], w2_s_shape[1],
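
Net effect of the loader changes: the per-tensor maxima over experts (the old `*_max` values) are gone. Each local expert `e` now gets `fc2_act_scale[e] = pre_quant_scale[e] / input_scale[e]` and `fc2_alpha[e] = weight_scale_2[e] * input_scale[e]`, and in VANILLA mode the per-group `w2_scales` are renormalized by that expert's own `weight_scale_2` rather than a shared maximum. A toy PyTorch sketch of the same arithmetic (shapes illustrative, names mirror the diff; not the module's actual loading code):

import torch

E, I, G = 2, 8, 4                   # local experts, intermediate size, groups
pre_quant_scale = torch.rand(E, I)  # per-expert, per-channel
input_scale = torch.rand(E)         # per-expert, per-tensor
weight_scale_2 = torch.rand(E)      # per-expert, per-tensor
w2_scales = torch.rand(E, I, G)     # per-expert, per-group (toy shape)

# Per-expert replacements for the old max-over-experts scalars:
fc2_act_scale = (pre_quant_scale / input_scale[:, None]).unsqueeze(-1)  # (E, I, 1)
fc2_alpha = (weight_scale_2 * input_scale)[:, None]                     # (E, 1)
w2_scales = w2_scales / weight_scale_2[:, None, None]  # VANILLA renorm per expert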
