tensorrt_llm/_torch/modules/fused_moe/quantization.py (6 additions, 2 deletions)

@@ -1144,16 +1144,20 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
             weight_scale_name = "weight_scale"
 
         assert (len(module.interleave) == 2)
+
+        # Ensure that the input_scale remains aligned across all ranks for W4A8 custom.
+        input_scale_expert_ids = module.initial_local_expert_ids if not w4a8_custom else range(
+            module.num_experts)
         # fc31 scales
         all_w3_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w3.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w1_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w1.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w3_w1_input_scales_max = torch.max(
             torch.stack(all_w3_input_scales),
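For context on the change: the reduction just below the diff (all_w3_w1_input_scales_max) folds the loaded per-expert input_scale values into one shared scale. Before this change each rank reduced only over module.initial_local_expert_ids, so under expert parallelism, ranks owning different experts could compute different fused scales. The following standalone sketch is not code from the PR; it is a minimal illustration, assuming a toy checkpoint of four experts split across two ranks, with hypothetical names and values (num_experts, the scale values, fused_scale) standing in for the real load_quant_scales path:

    import torch

    # Toy checkpoint: one input_scale per expert (hypothetical values).
    num_experts = 4
    weights = {f"{e}.w3.input_scale": torch.tensor(float(e + 1))
               for e in range(num_experts)}

    def fused_scale(expert_ids):
        # Mirrors the reduction pattern in load_quant_scales: stack the
        # per-expert scales and keep the max as the single shared scale.
        return torch.stack(
            [weights[f"{e}.w3.input_scale"] for e in expert_ids]).max()

    # Under expert parallelism, each rank owns a disjoint slice of experts.
    rank0_ids, rank1_ids = range(0, 2), range(2, 4)

    # Old behavior: each rank reduces only over its local experts, so the
    # fused scales disagree (2.0 on rank 0 vs. 4.0 on rank 1).
    print(fused_scale(rank0_ids), fused_scale(rank1_ids))

    # Behavior after the fix for W4A8 custom: every rank reduces over all
    # experts and arrives at the same scale (4.0).
    print(fused_scale(range(num_experts)), fused_scale(range(num_experts)))

Iterating range(module.num_experts) on every rank, as the diff does when w4a8_custom is set, makes the input to the max reduction identical everywhere, so the fused input_scale matches across ranks without any extra communication.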