@@ -910,9 +910,10 @@ def create_weights(self, module: torch.nn.Module):
             module.intermediate_size_per_partition // 2)
 
         # Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale
-        fc31_act_scale = nn.Parameter(torch.empty(1,
-                                                  module.hidden_size,
-                                                  dtype=module.dtype),
+        fc31_act_scale = nn.Parameter(torch.empty(
+            module.expert_size_per_partition,
+            module.hidden_size,
+            dtype=module.dtype),
                                       requires_grad=False)
         module.register_parameter("fc31_act_scale", fc31_act_scale)
 
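For reference, this first hunk switches fc31_act_scale from a single (1, hidden_size) buffer shared by all experts to one row of per-channel scales per local expert. A minimal standalone sketch of the new layout, using hypothetical sizes (the real module derives them from its config and tensor-parallel mapping):

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
expert_size_per_partition = 4
hidden_size = 16

# One row of per-channel activation scales per local expert.
fc31_act_scale = nn.Parameter(
    torch.ones(expert_size_per_partition, hidden_size, dtype=torch.bfloat16),
    requires_grad=False)

# Tokens routed to expert e are scaled by that expert's row
# (pre_quant_scale / input_scale) before quantization.
tokens = torch.randn(8, hidden_size, dtype=torch.bfloat16)
e = 2
scaled = tokens * fc31_act_scale[e]  # (hidden_size,) broadcasts over tokens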
@@ -1125,15 +1126,29 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w3_w1_pre_quant_scales_max = torch.max(
-            torch.stack(all_w3_pre_quant_scales +
-                        all_w1_pre_quant_scales).to(module.dtype),
+        all_w3_w1_pre_quant_scales_greater = torch.max(
+            torch.stack([
+                torch.stack(all_w3_pre_quant_scales),
+                torch.stack(all_w1_pre_quant_scales)
+            ]).to(module.dtype),
+            dim=0,
+        ).values.permute(1, 0)
+
+        all_w3_w1_input_scales_greater = torch.max(
+            torch.stack([
+                torch.stack(all_w3_input_scales),
+                torch.stack(all_w1_input_scales)
+            ]).to(module.dtype),
             dim=0,
         ).values
+
+        all_w3_w1_pre_quant_scales_div_input_scales = (
+            all_w3_w1_pre_quant_scales_greater *
+            (1 / all_w3_w1_input_scales_greater.reshape(
+                1, module.expert_size_per_partition).float()))
+
         module.fc31_act_scale.data.copy_(
-            torch.ones_like(module.fc31_act_scale, device=self.device) *
-            (all_w3_w1_pre_quant_scales_max) *
-            (1 / all_w3_w1_input_scales_max))
+            all_w3_w1_pre_quant_scales_div_input_scales.permute(1, 0))
         # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
         all_w3_weight_scale_2 = [
             load_weight_shard(weights[f"{expert_id}.w3.weight_scale_2"],
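This second hunk is the core of the change: instead of one global max over every expert's w3/w1 pre_quant_scale and input_scale, the max is now taken only between w3 and w1 within each expert, so every local expert keeps its own activation scale. The permute/reshape pair exists purely to line the (num_experts,) input scales up against the (hidden_size, num_experts) pre_quant scales for broadcasting. A shape-only sketch of that arithmetic with made-up tensors (names and sizes here are illustrative, not the module's):

import torch

E, H = 4, 16  # hypothetical local expert count and hidden size

w3_pre_quant = [torch.rand(H) for _ in range(E)]  # per-channel, per expert
w1_pre_quant = [torch.rand(H) for _ in range(E)]
w3_input = [torch.rand(()) for _ in range(E)]     # per-tensor, per expert
w1_input = [torch.rand(()) for _ in range(E)]

# Elementwise max of w3 vs. w1 per expert: (2, E, H) -> (E, H) -> (H, E).
pre_quant_greater = torch.max(
    torch.stack([torch.stack(w3_pre_quant),
                 torch.stack(w1_pre_quant)]),
    dim=0).values.permute(1, 0)

# Same for the per-tensor input scales: (2, E) -> (E,).
input_greater = torch.max(
    torch.stack([torch.stack(w3_input),
                 torch.stack(w1_input)]),
    dim=0).values

# Divide each expert's column by that expert's input scale, then permute back
# to (E, H), the layout of the per-expert fc31_act_scale buffer.
act_scale = (pre_quant_greater *
             (1 / input_greater.reshape(1, E))).permute(1, 0)
print(act_scale.shape)  # torch.Size([4, 16])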
@@ -1145,13 +1160,21 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w3_w1_weight_scale_2_max = torch.max(
-            torch.stack(all_w3_weight_scale_2 + all_w1_weight_scale_2).to(
-                module.dtype),
-            dim=0,
-        ).values
-        module.fc31_alpha.data.copy_(all_w3_w1_weight_scale_2_max.float() *
-                                     all_w3_w1_input_scales_max.float())
+        all_w3_w1_weight_scale_2 = torch.stack([
+            torch.stack(all_w3_weight_scale_2),
+            torch.stack(all_w1_weight_scale_2)
+        ]).to(module.dtype)
+        all_w3_w1_weight_scale_2_greater = torch.max(
+            all_w3_w1_weight_scale_2, dim=0).values
+
+        all_w3_w1_weight_scale_2_mul_input_scales = (
+            all_w3_w1_weight_scale_2_greater.reshape(
+                module.expert_size_per_partition, 1).float() *
+            all_w3_w1_input_scales_greater.reshape(
+                module.expert_size_per_partition, 1).float())
+        module.fc31_alpha.data.copy_(
+            all_w3_w1_weight_scale_2_mul_input_scales.reshape(
+                module.expert_size_per_partition, 1).float())
 
         # Per-group weight_scale
         all_w3_scales = [
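The third hunk does the analogous thing for fc31_alpha: rather than a single weight_scale_2 * input_scale product, each local expert gets its own, written as an (expert_size_per_partition, 1) column. A self-contained shape sketch, again with made-up values:

import torch

E = 4  # hypothetical local expert count

w3_ws2 = [torch.rand(()) for _ in range(E)]  # per-tensor weight_scale_2
w1_ws2 = [torch.rand(()) for _ in range(E)]
input_greater = torch.rand(E)                # per-expert input scale

# Elementwise max of w3 vs. w1 weight_scale_2 per expert: (2, E) -> (E,).
ws2_greater = torch.max(
    torch.stack([torch.stack(w3_ws2),
                 torch.stack(w1_ws2)]),
    dim=0).values

# Per-expert alpha = weight_scale_2 * input_scale, stored as an (E, 1) column.
alpha = ws2_greater.reshape(E, 1).float() * input_greater.reshape(E, 1).float()
print(alpha.shape)  # torch.Size([4, 1])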
@@ -1179,7 +1202,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(
             module.dtype)
         if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-            w3_w1_scales /= all_w3_w1_weight_scale_2_max.float()
+            w3_w1_scales = w3_w1_scales.permute(1, 2, 0)
+            w3_w1_scales /= all_w3_w1_weight_scale_2_greater.reshape(
+                module.expert_size_per_partition).float()
+            w3_w1_scales = w3_w1_scales.permute(2, 0, 1)
+
         w3_w1_s_shape = w3_w1_scales.shape
         w3_w1_scales_interleaved = w3_w1_scales.reshape(
             w3_w1_s_shape[0], w3_w1_s_shape[1],
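Finally, because weight_scale_2 is now per expert rather than a scalar, the per-group w3/w1 scales in the last hunk can no longer be divided in place by a single value. The first permute moves the expert dimension last so the (num_experts,) vector broadcasts across the row/group dimensions, and the second permute restores the original layout. A small sketch showing that the permute-divide-permute is equivalent to a plain reshape-based broadcast:

import torch

E, R, G = 4, 32, 8  # hypothetical experts, rows, scale groups

w3_w1_scales = torch.rand(E, R, G)
ws2_greater = torch.rand(E)  # per-expert weight_scale_2, as in the sketch above

# Move experts last, divide by the per-expert scalar, restore the layout.
scaled = w3_w1_scales.permute(1, 2, 0)            # (R, G, E)
scaled = scaled / ws2_greater.reshape(E).float()  # broadcasts over (R, G, E)
scaled = scaled.permute(2, 0, 1)                  # back to (E, R, G)

# Equivalent without permutes, for comparison.
same = w3_w1_scales / ws2_greater.reshape(E, 1, 1)
print(torch.allclose(scaled, same))  # True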