@@ -622,9 +622,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)
 
         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -637,13 +640,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
 
         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
 
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
@@ -655,11 +668,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))
 
         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
 
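The change replaces the blanket trailing .squeeze() with a targeted squeeze of the two singleton dimensions before tensor-parallel sharding. A minimal sketch of that shape handling (not code from this PR): the [out_blocks, 1, in_blocks, 1] checkpoint shape, the tp_size/tp_rank values, and the torch.chunk stand-in for load_weight_shard are illustrative assumptions.

```python
import torch

# Hypothetical per-block weight scale as a modelopt fp8_pb_wo checkpoint might
# store it, with two extra singleton dimensions: [out_blocks, 1, in_blocks, 1].
full_weight_scale = torch.rand(8, 1, 4, 1)

# Drop the two singleton dimensions, mirroring the check added in the diff.
if full_weight_scale.dim() == 4:
    full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)  # -> [8, 4]

# Simplified stand-in for load_weight_shard: take this rank's slice along the
# output-block dimension (tp_size=2, tp_rank=0 assumed for illustration).
tp_size, tp_rank = 2, 0
weight_scale = torch.chunk(full_weight_scale, tp_size, dim=0)[tp_rank]
print(weight_scale.shape)  # torch.Size([4, 4])
```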