
Commit 548da52

achartier authored and dominicshanshan committed
[https://nvbugs/5449155][fix] Fix DeepSeek R1 weight loading for TP16 (#6913)
Signed-off-by: Aurelien Chartier <[email protected]>
1 parent 93f15aa commit 548da52

File tree

1 file changed (+30, -10 lines)

tensorrt_llm/_torch/modules/linear.py

Lines changed: 30 additions & 10 deletions
@@ -622,9 +622,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)
 
         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
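
The new path normalizes the modelopt fp8_pb_wo block scale before it is sharded: a checkpoint that carries two extra singleton dimensions (a 4-D tensor) is reduced to the 2-D layout that `load_weight_shard` and `copy_weight` expect. A minimal sketch of that normalization follows; the block counts and the helper name are illustrative assumptions, not taken from the repository.

```python
import torch

def normalize_block_scale(scale: torch.Tensor) -> torch.Tensor:
    """Illustrative helper mirroring the squeeze(1).squeeze(-1) in the diff:
    drop the two singleton dims a modelopt fp8_pb_wo checkpoint may carry.
    Scales that are already 2-D pass through unchanged."""
    if scale.dim() == 4:
        scale = scale.squeeze(1).squeeze(-1)
    return scale

# Hypothetical block counts, for shape illustration only.
scale_4d = torch.ones(64, 1, 56, 1)   # exported with extra singleton dims
scale_2d = torch.ones(64, 56)         # already in the expected 2-D layout
assert normalize_block_scale(scale_4d).shape == (64, 56)
assert normalize_block_scale(scale_2d).shape == (64, 56)
```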
@@ -637,13 +640,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
 
         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
 
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
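
For the fused QKV path, each projection's scale is squeezed to 2-D first, sharded per rank, and only then concatenated along the output-block dimension, so the fused block scale stays 2-D. A rough sketch under assumed shapes (the block counts and tp_size are made up, and a plain row chunk stands in for `load_weight_shard`):

```python
import torch

def shard_rows(t: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    """Stand-in for load_weight_shard: take this rank's slice of dim 0."""
    return t.chunk(tp_size, dim=0)[tp_rank]

tp_size, tp_rank = 16, 0               # assumed TP16 setup
q_scale_full = torch.ones(64, 56)      # hypothetical [out_blocks, in_blocks]
k_scale_full = torch.ones(16, 56)
v_scale_full = torch.ones(16, 56)

q = shard_rows(q_scale_full, tp_size, tp_rank)   # (4, 56)
k = shard_rows(k_scale_full, tp_size, tp_rank)   # (1, 56)
v = shard_rows(v_scale_full, tp_size, tp_rank)   # (1, 56)

# Concatenating the 2-D shards keeps the fused block scale 2-D: (6, 56).
fused = torch.cat((q, k, v))
print(fused.shape)
```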
@@ -655,11 +668,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))
 
         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
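
All three loaders also drop the trailing `.squeeze()` that the old code applied to the sharded or fused scale. One plausible reading, consistent with the TP16 fix in the commit title, is that at a high tensor-parallel degree a per-rank shard can have a size-1 leading dimension, which an unconditional `squeeze()` would silently collapse and change the tensor's rank. A small sketch of that hazard, with made-up block counts:

```python
import torch

full_scale = torch.ones(16, 56)              # hypothetical [out_blocks, in_blocks]
shard = full_scale.chunk(16, dim=0)[0]       # one rank's slice at tp_size=16

print(shard.shape)             # torch.Size([1, 56]) -- 2-D, as the loader expects
print(shard.squeeze().shape)   # torch.Size([56])    -- unconditional squeeze drops the dim
```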

Comments (0)