
Commit d8a8dc8

Ensure that the input_scale remains aligned across all ranks for W4A8 custom ckpts.
Signed-off-by: Yilin Zhang <[email protected]>
1 parent faa2f46 commit d8a8dc8

File tree: 1 file changed (+6 −2)


tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 6 additions & 2 deletions
@@ -1144,16 +1144,20 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         weight_scale_name = "weight_scale"
 
         assert (len(module.interleave) == 2)
+
+        # Ensure that the input_scale remains aligned across all ranks for W4A8 custom.
+        input_scale_expert_ids = module.initial_local_expert_ids if not w4a8_custom else range(
+            module.num_experts)
         # fc31 scales
         all_w3_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w3.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w1_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w1.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w3_w1_input_scales_max = torch.max(
             torch.stack(all_w3_input_scales),