
Commit c0d8b42

Apply per expert act scale to FC1 for w4a8 moe on PyT flow
Signed-off-by: Min Yu <[email protected]>
1 parent ab26d21 commit c0d8b42

2 files changed: +49, -18 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 5 additions & 1 deletion
@@ -1508,14 +1508,18 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         static_assert(!is_nvfp4 && !is_mxfp8, "NVFP4 and MXFP8 are not supported for AWQ");
         static_assert(!std::is_same_v<InputActivationsType, ExpandedActivationsType>,
             "Input and output types must be different for AWQ");
+        int64_t expert = findTotalEltsLessThanTarget(
+                             expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1)
+            - 1;
         for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
         {
             auto frag_elems = source_row_ptr[elem_index];

             CUTLASS_PRAGMA_UNROLL
             for (int e = 0; e < ELEM_PER_THREAD; e++)
             {
-                frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e];
+                frag_elems[e]
+                    = frag_elems[e] * prequant_scales[expert * hidden_size + elem_index * ELEM_PER_THREAD + e];
             }

             dest_row_ptr[elem_index] = arrayConvert<DataElem, OutputElem>(frag_elems);
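
For reference, a minimal PyTorch sketch of the indexing the new kernel code performs (names and shapes here are illustrative, not the kernel's actual variables): expert_first_token_offset holds the cumulative first-row offset of each expert in the permuted activation buffer, so the owning expert of a permuted row can be found with a findTotalEltsLessThanTarget-style search, and that expert's row of a [num_experts, hidden_size] prequant_scales table is applied to the activation row.

import torch

# Illustrative sizes only.
num_experts_per_node, hidden_size = 4, 8

# Cumulative first-row offset per expert; the last entry is the total row count,
# so expert 0 owns rows [0, 3), expert 1 owns rows [3, 7), and so on.
expert_first_token_offset = torch.tensor([0, 3, 7, 9, 10])
num_rows = int(expert_first_token_offset[-1])

prequant_scales = torch.rand(num_experts_per_node, hidden_size)  # per-expert, per-channel
permuted_rows = torch.rand(num_rows, hidden_size)

# Equivalent of findTotalEltsLessThanTarget(offsets, n, row + 1) - 1: the index of the
# last expert whose first-row offset is <= the permuted row index.
row_idx = torch.arange(num_rows)
expert = torch.searchsorted(expert_first_token_offset, row_idx, right=True) - 1

# Per-expert scaling; in the flat kernel view this reads
# prequant_scales[expert * hidden_size + col] instead of prequant_scales[col].
scaled_rows = permuted_rows * prequant_scales[expert]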

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 44 additions & 17 deletions
@@ -910,9 +910,10 @@ def create_weights(self, module: torch.nn.Module):
                             module.intermediate_size_per_partition // 2)

         # Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale
-        fc31_act_scale = nn.Parameter(torch.empty(1,
-                                                  module.hidden_size,
-                                                  dtype=module.dtype),
+        fc31_act_scale = nn.Parameter(torch.empty(
+            module.expert_size_per_partition,
+            module.hidden_size,
+            dtype=module.dtype),
                                       requires_grad=False)
         module.register_parameter("fc31_act_scale", fc31_act_scale)

@@ -1125,15 +1126,29 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w3_w1_pre_quant_scales_max = torch.max(
-            torch.stack(all_w3_pre_quant_scales +
-                        all_w1_pre_quant_scales).to(module.dtype),
+        all_w3_w1_pre_quant_scales_greater = torch.max(
+            torch.stack([
+                torch.stack(all_w3_pre_quant_scales),
+                torch.stack(all_w1_pre_quant_scales)
+            ]).to(module.dtype),
+            dim=0,
+        ).values.permute(1, 0)
+
+        all_w3_w1_input_scales_greater = torch.max(
+            torch.stack([
+                torch.stack(all_w3_input_scales),
+                torch.stack(all_w1_input_scales)
+            ]).to(module.dtype),
             dim=0,
         ).values
+
+        all_w3_w1_pre_quant_scales_div_input_scales = (
+            all_w3_w1_pre_quant_scales_greater *
+            (1 / all_w3_w1_input_scales_greater.reshape(
+                1, module.expert_size_per_partition).float()))
+
         module.fc31_act_scale.data.copy_(
-            torch.ones_like(module.fc31_act_scale, device=self.device) *
-            (all_w3_w1_pre_quant_scales_max) *
-            (1 / all_w3_w1_input_scales_max))
+            all_w3_w1_pre_quant_scales_div_input_scales.permute(1, 0))
         # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
         all_w3_weight_scale_2 = [
             load_weight_shard(weights[f"{expert_id}.w3.weight_scale_2"],
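
The net effect of the block above is that fc31_act_scale now carries one row per expert: for each expert, the element-wise max of the w3/w1 pre_quant_scale vectors divided by the larger of the two per-tensor input scales. A simplified sketch of the same computation, using stand-in tensors and plain broadcasting instead of the reshape/permute sequence in the diff:

import torch

# Stand-ins for the per-expert scales loaded from the checkpoint; sizes are illustrative.
num_experts, hidden_size = 4, 8
dtype = torch.bfloat16
all_w3_pre_quant_scales = [torch.rand(hidden_size) for _ in range(num_experts)]
all_w1_pre_quant_scales = [torch.rand(hidden_size) for _ in range(num_experts)]
all_w3_input_scales = [torch.rand(()) for _ in range(num_experts)]
all_w1_input_scales = [torch.rand(()) for _ in range(num_experts)]

# Element-wise max of the w3/w1 pre_quant_scale per expert -> [num_experts, hidden_size].
pre_quant = torch.max(
    torch.stack([torch.stack(all_w3_pre_quant_scales),
                 torch.stack(all_w1_pre_quant_scales)]).to(dtype),
    dim=0).values

# Larger of the w3/w1 per-tensor input_scale per expert -> [num_experts].
input_scales = torch.max(
    torch.stack([torch.stack(all_w3_input_scales),
                 torch.stack(all_w1_input_scales)]).to(dtype),
    dim=0).values

# fc31_act_scale[e, c] = pre_quant_scale[e, c] / input_scale[e].
fc31_act_scale = pre_quant * (1.0 / input_scales.float()).unsqueeze(-1)
assert fc31_act_scale.shape == (num_experts, hidden_size)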
@@ -1145,13 +1160,21 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w3_w1_weight_scale_2_max = torch.max(
-            torch.stack(all_w3_weight_scale_2 + all_w1_weight_scale_2).to(
-                module.dtype),
-            dim=0,
-        ).values
-        module.fc31_alpha.data.copy_(all_w3_w1_weight_scale_2_max.float() *
-                                     all_w3_w1_input_scales_max.float())
+        all_w3_w1_weight_scale_2 = torch.stack([
+            torch.stack(all_w3_weight_scale_2),
+            torch.stack(all_w1_weight_scale_2)
+        ]).to(module.dtype)
+        all_w3_w1_weight_scale_2_greater = torch.max(
+            all_w3_w1_weight_scale_2, dim=0).values
+
+        all_w3_w1_weight_scale_2_mul_input_scales = (
+            all_w3_w1_weight_scale_2_greater.reshape(
+                module.expert_size_per_partition, 1).float() *
+            all_w3_w1_input_scales_greater.reshape(
+                module.expert_size_per_partition, 1).float())
+        module.fc31_alpha.data.copy_(
+            all_w3_w1_weight_scale_2_mul_input_scales.reshape(
+                module.expert_size_per_partition, 1).float())

         # Per-group weight_scale
         all_w3_scales = [
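
Similarly, fc31_alpha becomes a per-expert quantity: the larger of the two per-tensor weight_scale_2 values (w3 vs. w1) multiplied by the per-expert input scale, stored with shape [num_experts, 1]. A stand-alone sketch with stand-in values:

import torch

# Stand-ins; in the checkpoint these are per-tensor (scalar) scales, one per expert.
num_experts = 4
dtype = torch.bfloat16
all_w3_weight_scale_2 = [torch.rand(()) for _ in range(num_experts)]
all_w1_weight_scale_2 = [torch.rand(()) for _ in range(num_experts)]
input_scales = torch.rand(num_experts)  # per-expert max of the w3/w1 input_scale, as above

# Larger of the two weight_scale_2 values per expert -> [num_experts].
weight_scale_2 = torch.max(
    torch.stack([torch.stack(all_w3_weight_scale_2),
                 torch.stack(all_w1_weight_scale_2)]).to(dtype),
    dim=0).values

# fc31_alpha[e] = weight_scale_2[e] * input_scale[e], stored as [num_experts, 1].
fc31_alpha = (weight_scale_2.float() * input_scales.float()).reshape(num_experts, 1)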
@@ -1179,7 +1202,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(
             module.dtype)
         if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-            w3_w1_scales /= all_w3_w1_weight_scale_2_max.float()
+            w3_w1_scales = w3_w1_scales.permute(1, 2, 0)
+            w3_w1_scales /= all_w3_w1_weight_scale_2_greater.reshape(
+                module.expert_size_per_partition).float()
+            w3_w1_scales = w3_w1_scales.permute(2, 0, 1)
+
         w3_w1_s_shape = w3_w1_scales.shape
         w3_w1_scales_interleaved = w3_w1_scales.reshape(
             w3_w1_s_shape[0], w3_w1_s_shape[1],
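
In VANILLA mode the per-group weight scales are now divided by the per-expert weight_scale_2 instead of a single shared value; the permutes simply move the expert axis into broadcasting position. A sketch assuming a [num_experts, out_channels, num_groups] layout (the real layout may differ):

import torch

# Illustrative shapes only.
num_experts, out_channels, num_groups = 4, 16, 2
w3_w1_scales = torch.rand(num_experts, out_channels, num_groups)
weight_scale_2 = torch.rand(num_experts)  # per-expert, as in the previous sketch

# The diff permutes so the expert axis is last, divides, and permutes back.
scaled = w3_w1_scales.permute(1, 2, 0)            # [out_channels, num_groups, num_experts]
scaled = scaled / weight_scale_2.reshape(num_experts).float()
scaled = scaled.permute(2, 0, 1)                  # back to [num_experts, out_channels, num_groups]

# Equivalent broadcast without the permutes:
assert torch.allclose(scaled, w3_w1_scales / weight_scale_2.reshape(num_experts, 1, 1))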
