@@ -468,45 +468,8 @@ def create_weights(self, module: torch.nn.Module):
468468
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                 weight_loading_mode: MoEWeightLoadingMode):
    """Load MoE weights, adding scale pre/post-processing when SM version is 100.

    When ``get_sm_version()`` returns 100, this method:

    1. Before delegating to the parent loader, replaces every
       ``*weight_scale_inv`` entry (and its matching ``*weight`` entry) in
       ``weights`` with the pair returned by ``resmooth_to_fp8_e8m0``. Only
       entries whose leading dot-separated integer is one of the expert ids
       to be loaded are touched. ``weights`` is modified in place.
    2. After the parent loader has run, passes ``module.quant_scales[0]`` and
       ``module.quant_scales[1]`` through ``transform_sf_into_required_layout``
       and stores the results as non-trainable parameters
       ``module.w3_w1_weight_scaling_factor`` and
       ``module.w2_weight_scaling_factor``, then calls
       ``self.setup_quant_scales(module)``.

    For any other SM version the call is a plain delegation to the parent.

    Args:
        module: The MoE module whose weights are being loaded.
        weights: Checkpoint weights, forwarded to the parent loader.
            NOTE(review): annotated as ``List[Dict]`` but used here as a
            mapping (``.keys()``, item assignment) — confirm what callers
            actually pass.
        weight_loading_mode: Forwarded unchanged to the parent loader.
    """
    if get_sm_version() == 100:
        wanted_experts = set(module.initial_local_expert_ids)
        if self.need_load_shared_weights(module):
            wanted_experts.update(
                module.layer_load_balancer.get_load_expert_ids())

        # Snapshot the keys up front because entries are reassigned below.
        for scale_key in list(weights.keys()):
            if not scale_key.endswith("weight_scale_inv"):
                continue
            # The first dot-separated field of the key is parsed as the
            # expert id; skip experts that are not loaded here.
            if int(scale_key.split(".")[0]) not in wanted_experts:
                continue
            weight_name = scale_key.replace("weight_scale_inv", "weight")
            logger.debug(f"Resmoothing { weight_name } ")
            resmoothed = resmooth_to_fp8_e8m0(weights[weight_name][:],
                                              weights[scale_key][:])
            weights[weight_name], weights[scale_key] = resmoothed

    super().load_weights(module, weights, weight_loading_mode)

    if get_sm_version() != 100:
        return

    # Re-layout the scales produced by the parent loader.
    fc1_scale = transform_sf_into_required_layout(
        module.quant_scales[0],
        mn=module.w3_w1_weight.shape[1],
        k=module.w3_w1_weight.shape[2],
        recipe=(1, 128, 128),
        num_groups=module.w3_w1_weight.shape[0],
        is_sfa=False)
    module.w3_w1_weight_scaling_factor = nn.Parameter(fc1_scale,
                                                      requires_grad=False)

    # num_groups is taken from w3_w1_weight here as well, as in the
    # original code.
    fc2_scale = transform_sf_into_required_layout(
        module.quant_scales[1],
        mn=module.w2_weight.shape[1],
        k=module.w2_weight.shape[2],
        recipe=(1, 128, 128),
        num_groups=module.w3_w1_weight.shape[0],
        is_sfa=False)
    module.w2_weight_scaling_factor = nn.Parameter(fc2_scale,
                                                   requires_grad=False)

    # Rebuild module.quant_scales so it refers to the re-laid-out tensors.
    self.setup_quant_scales(module)
510473 def setup_quant_scales (self , module : torch .nn .Module ):
511474 module .quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales (
512475 fc_weight_scales = module .w3_w1_weight_scaling_factor ,
@@ -603,6 +566,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
603566 })
604567
605568
class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
        DeepSeekFP8BlockScalesFusedMoEMethod):
    """FP8 block-scales MoE method with extra scale handling on SM version 100.

    Overrides only ``load_weights``; everything else is inherited from
    ``DeepSeekFP8BlockScalesFusedMoEMethod``.
    """

    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode):
        """Load weights via the parent, wrapping it with SM-100-only steps.

        When ``get_sm_version()`` returns 100:

        * Before the parent loader runs, each ``*weight_scale_inv`` entry in
          ``weights`` whose leading dot-separated integer is an expert id to
          be loaded is replaced, together with its ``*weight`` entry, by the
          pair returned from ``resmooth_to_fp8_e8m0``. ``weights`` is
          modified in place.
        * After the parent loader runs, ``module.quant_scales[0]`` and
          ``module.quant_scales[1]`` are passed through
          ``transform_sf_into_required_layout`` and stored as non-trainable
          parameters on ``module``; ``setup_quant_scales`` is then called.

        On any other SM version this is a plain call to the parent loader.

        Args:
            module: The MoE module whose weights are being loaded.
            weights: Checkpoint weights, forwarded to the parent loader.
                NOTE(review): annotated as ``List[Dict]`` but used here as a
                mapping (``.keys()``, item assignment) — confirm what callers
                actually pass.
            weight_loading_mode: Forwarded unchanged to the parent loader.
        """
        if get_sm_version() == 100:
            local_experts = set(module.initial_local_expert_ids)
            if self.need_load_shared_weights(module):
                local_experts.update(
                    module.layer_load_balancer.get_load_expert_ids())

            # Iterate over a copy of the keys: entries are rewritten below.
            for scale_name in list(weights.keys()):
                if not scale_name.endswith("weight_scale_inv"):
                    continue
                # The key's first dot-separated field is parsed as the
                # expert id; experts not loaded here are left untouched.
                if int(scale_name.split(".")[0]) not in local_experts:
                    continue
                weight_name = scale_name.replace("weight_scale_inv", "weight")
                logger.debug(f"Resmoothing { weight_name } ")
                raw_weight = weights[weight_name][:]
                raw_scale = weights[scale_name][:]
                new_weight, new_scale = resmooth_to_fp8_e8m0(
                    raw_weight, raw_scale)
                weights[weight_name] = new_weight
                weights[scale_name] = new_scale

        super().load_weights(module, weights, weight_loading_mode)

        if get_sm_version() != 100:
            return

        # Re-layout the scales that the parent loader produced.
        w3_w1_scale = transform_sf_into_required_layout(
            module.quant_scales[0],
            mn=module.w3_w1_weight.shape[1],
            k=module.w3_w1_weight.shape[2],
            recipe=(1, 128, 128),
            num_groups=module.w3_w1_weight.shape[0],
            is_sfa=False)
        module.w3_w1_weight_scaling_factor = nn.Parameter(
            w3_w1_scale, requires_grad=False)

        # num_groups is read from w3_w1_weight for the w2 scale too, exactly
        # as the original code does.
        w2_scale = transform_sf_into_required_layout(
            module.quant_scales[1],
            mn=module.w2_weight.shape[1],
            k=module.w2_weight.shape[2],
            recipe=(1, 128, 128),
            num_groups=module.w3_w1_weight.shape[0],
            is_sfa=False)
        module.w2_weight_scaling_factor = nn.Parameter(w2_scale,
                                                       requires_grad=False)

        # Refresh module.quant_scales so it points at the new parameters.
        self.setup_quant_scales(module)
611+
612+
606613class WInt4AFP8FusedMoEMethod (FusedMoEMethodBase ):
607614
608615 def create_weights (self , module : torch .nn .Module ):
0 commit comments