
Commit 48fda86

[None][fix] Fix dummy load format for DeepSeek. (#7874)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 6e5e8b8 commit 48fda86

File tree

11 files changed, +87 -68 lines


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 29 additions & 22 deletions
@@ -393,28 +393,6 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
             for n, p in module.named_parameters():
                 p.data.copy_(module_weights[n][:])
 
-            if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-            ) and is_sm_100f() and hasattr(module, "weight_scale"):
-                weight, weight_scale = resmooth_to_fp8_e8m0(
-                    module.weight, module.weight_scale)
-                transfromed_scale = transform_sf_into_required_layout(
-                    weight_scale,
-                    mn=weight.shape[0],
-                    k=weight.shape[1],
-                    recipe=(1, 128, 128),
-                    is_sfa=False)
-                module.weight = nn.Parameter(weight, requires_grad=False)
-                module.weight_scale = nn.Parameter(transfromed_scale,
-                                                   requires_grad=False)
-        if not self.is_draft_model:
-            for idx, layer in enumerate(
-                    self.model.model.layers[:self.config.num_hidden_layers]):
-                if idx == self.config.num_hidden_layers - 1:
-                    layer.next_layer_layernorm = self.model.model.norm
-                else:
-                    layer.next_layer_layernorm = self.model.model.layers[
-                        idx + 1].input_layernorm
-
 
 class DeepseekV3MTPHead(nn.Module):
 
@@ -1540,3 +1518,32 @@ def forward(
     def load_weights(self, weights: Dict):
         weight_loader = DeepseekV3WeightLoader(self)
         weight_loader.load_weights(weights)
+
+    def post_load_weights(self):
+        all_named_modules = dict(self.model.named_modules())
+        for name, module in tqdm(all_named_modules.items(),
+                                 desc="Post loading weights"):
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
+                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
+                ) and is_sm_100f() and hasattr(module, "weight_scale"):
+                    weight, weight_scale = resmooth_to_fp8_e8m0(
+                        module.weight, module.weight_scale)
+                    transfromed_scale = transform_sf_into_required_layout(
+                        weight_scale,
+                        mn=weight.shape[0],
+                        k=weight.shape[1],
+                        recipe=(1, 128, 128),
+                        is_sfa=False)
+                    module.weight = nn.Parameter(weight, requires_grad=False)
+                    module.weight_scale = nn.Parameter(transfromed_scale,
+                                                       requires_grad=False)
+
+        for idx, layer in enumerate(
+                self.model.layers[:self.config.num_hidden_layers]):
+            if idx == self.config.num_hidden_layers - 1:
+                layer.next_layer_layernorm = self.model.norm
+            else:
+                layer.next_layer_layernorm = self.model.layers[
+                    idx + 1].input_layernorm
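
Editorial note: the structural change in this file is that the FP8 block-scale resmoothing and the next_layer_layernorm chaining move out of the weight loader and into a model-level post_load_weights(), so they run no matter how the weights were materialized (real checkpoint or dummy load format). Below is a minimal sketch of that shape, assuming hypothetical ToyBlock/ToyModel classes and a trivial stand-in for the scale resmoothing; it is illustrative only, not the TensorRT-LLM API.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a decoder layer with a per-layer norm."""

    def __init__(self, hidden: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden)
        self.linear = nn.Linear(hidden, hidden, bias=False)
        self.next_layer_layernorm = None  # wired up after weight loading


class ToyModel(nn.Module):
    def __init__(self, hidden: int = 64, num_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(hidden) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(hidden)

    def post_load_weights(self):
        # 1) Per-module fix-ups that must happen after *any* weight init path.
        for name, module in self.named_modules():
            if not module._parameters or name.startswith("draft_model"):
                continue
            if hasattr(module, "weight") and module.weight is not None:
                # Placeholder for the real FP8 block-scale resmoothing.
                module.weight = nn.Parameter(module.weight.contiguous(),
                                             requires_grad=False)
        # 2) Chain each layer to the next layer's input layernorm.
        for idx, layer in enumerate(self.layers):
            if idx == len(self.layers) - 1:
                layer.next_layer_layernorm = self.norm
            else:
                layer.next_layer_layernorm = self.layers[idx + 1].input_layernorm


model = ToyModel()
model.post_load_weights()
assert model.layers[0].next_layer_layernorm is model.layers[1].input_layernorm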

tensorrt_llm/_torch/models/modeling_llama_min_latency.py

Lines changed: 28 additions & 46 deletions
@@ -1,5 +1,4 @@
 from collections.abc import Callable
-from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -66,7 +65,6 @@ def __init__(
         enable_fused_gemm_swiglu: bool = False,
         enable_fused_gemm_attn_scaling: bool = False,
         enable_trtllm_gen: bool = False,
-        post_load_weights_hook: Optional[Callable] = None,
     ):
         # First, initialize the base class.
         super().__init__(
@@ -88,7 +86,6 @@ def __init__(
         self.enable_fused_gemm_swiglu = enable_fused_gemm_swiglu
         self.enable_fused_gemm_attn_scaling = enable_fused_gemm_attn_scaling
         self.enable_trtllm_gen = enable_trtllm_gen
-        self.post_load_weights_hook = post_load_weights_hook
         self.position_ids = None
 
     def load_weights(self, weights: List[Dict]):
@@ -123,9 +120,6 @@ def load_weights(self, weights: List[Dict]):
                 self.weight.view(torch.uint8),
                 128).view(torch.float8_e4m3fn)
 
-        if self.post_load_weights_hook is not None:
-            self.post_load_weights_hook(self)
-
     # Override apply_linear instead of forward so that we can reuse the AllReduce/AllGather logic in the parent class.
     def apply_linear(
         self,
@@ -298,17 +292,6 @@ def __init__(self,
             enable_trtllm_gen=True,
         )
 
-        # After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by
-        # the trtllm-gen gemm+swiglu kernel.
-        def post_load_weights_hook(gate_up_proj, down_proj):
-            if gate_up_proj.has_fp8_qdq:
-                # For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse
-                # of down_proj's combined input scale.
-                gate_up_proj.inv_output_scale = 1.0 / down_proj.input_scale
-                # For the trtllm-gen gemm+swiglu kernel, we need to set the global scale, which is gate_up_proj's
-                # combined input scale times inv_output_scale.
-                gate_up_proj.trtllm_gen_global_scale = gate_up_proj.combined_scale * gate_up_proj.inv_output_scale
-
         self.down_proj = Llama4MinLatencyLinear(
             self.intermediate_size,
             self.hidden_size,
@@ -320,10 +303,19 @@ def post_load_weights_hook(gate_up_proj, down_proj):
             reduce_output=reduce_output,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             enable_trtllm_gen=True,
-            post_load_weights_hook=partial(post_load_weights_hook,
-                                           self.gate_up_proj),
         )
 
+    # After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by
+    # the trtllm-gen gemm+swiglu kernel.
+    def post_load_weights(self):
+        if self.gate_up_proj.has_fp8_qdq:
+            # For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse
+            # of down_proj's combined input scale.
+            self.gate_up_proj.inv_output_scale = 1.0 / self.down_proj.input_scale
+            # For the trtllm-gen gemm+swiglu kernel, we need to set the global scale, which is gate_up_proj's
+            # combined input scale times inv_output_scale.
+            self.gate_up_proj.trtllm_gen_global_scale = self.gate_up_proj.combined_scale * self.gate_up_proj.inv_output_scale
+
     def forward(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -450,7 +442,6 @@ def __init__(
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
-        post_load_weights_hook: Optional[Callable] = None,
     ):
 
         super().__init__(
@@ -466,8 +457,6 @@ def __init__(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
-        self.post_load_weights_hook = post_load_weights_hook
-
         # Enable min-latency mode for Llama4 Maverick TP8 EP1.
         self.enable_min_latency_fused_moe = False
         if num_experts == 128 \
@@ -481,12 +470,6 @@ def __init__(
                 and apply_router_weight_on_input:
             self.enable_min_latency_fused_moe = True
 
-    def load_weights(self, weights: List[Dict]):
-        super().load_weights(weights)
-
-        if self.post_load_weights_hook:
-            self.post_load_weights_hook(self)
-
     def forward(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -560,22 +543,6 @@ def __init__(
             overridden_tp_size=1 if self.enable_attention_dp else None,
             reduce_output=False)
 
-        def post_load_weights_hook(shared_expert, experts):
-            # Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels.
-            # This is because the routed experts' input scale is after the score multiplication, so we must use the
-            # pre-score scaling input scale, which happens to be shared expert's input scale.
-            if experts.enable_min_latency_fused_moe and hasattr(
-                    shared_expert.gate_up_proj, "input_scale"):
-                pre_score_scaling_input_scale = shared_expert.gate_up_proj.input_scale
-                experts.min_latency_quant_scales = FusedMoEQuantScalesFP8(
-                    fc1_dequant=experts.fc31_dequant.data /
-                    experts.fc31_input_dequant.data *
-                    pre_score_scaling_input_scale,
-                    fc2_quant=experts.fc2_quant,
-                    fc2_dequant=experts.fc2_dequant,
-                    fc1_input_dequant=pre_score_scaling_input_scale,
-                )
-
         self.experts = Llama4MinLatencyFusedMoE(
             routing_method=Llama4RenormalizeMoeRoutingMethod(top_k),
             num_experts=num_experts,
@@ -587,8 +554,7 @@ def post_load_weights_hook(shared_expert, experts):
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
             model_config=model_config,
             apply_router_weight_on_input=True,
-            post_load_weights_hook=partial(post_load_weights_hook,
-                                           self.shared_expert))
+        )
 
         self.router = Llama4MinLatencyLinear(
             hidden_size,
@@ -597,6 +563,22 @@ def post_load_weights_hook(shared_expert, experts):
             dtype=model_config.pretrained_config.torch_dtype,
             quant_config=None)
 
+    def post_load_weights(self):
+        # Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels.
+        # This is because the routed experts' input scale is after the score multiplication, so we must use the
+        # pre-score scaling input scale, which happens to be shared expert's input scale.
+        if self.experts.enable_min_latency_fused_moe and hasattr(
+                self.shared_expert.gate_up_proj, "input_scale"):
+            pre_score_scaling_input_scale = self.shared_expert.gate_up_proj.input_scale
+            self.experts.min_latency_quant_scales = FusedMoEQuantScalesFP8(
+                fc1_dequant=self.experts.fc31_dequant.data /
+                self.experts.fc31_input_dequant.data *
+                pre_score_scaling_input_scale,
+                fc2_quant=self.experts.fc2_quant,
+                fc2_dequant=self.experts.fc2_dequant,
+                fc1_input_dequant=pre_score_scaling_input_scale,
+            )
+
     def compute_routed_output(
         self,
         hidden_states,
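
Editorial note: nothing in this file changes the scale math itself; the per-module hooks are just folded into post_load_weights() methods. For reference, here is the arithmetic those comments describe, written out with made-up scale values; the names mirror the diff, but none of this is the real TensorRT-LLM API.

import torch

# Hypothetical FP8 scales, standing in for values read from a checkpoint.
gate_up_combined_scale = torch.tensor(0.02)   # gate_up_proj.combined_scale
down_proj_input_scale = torch.tensor(0.05)    # down_proj.input_scale

# Fused gemm+swiglu kernel: the output scale of gate_up_proj must match the
# input scale of down_proj, so store its inverse.
inv_output_scale = 1.0 / down_proj_input_scale                        # 20.0

# trtllm-gen gemm+swiglu kernel wants a single global scale:
# gate_up_proj's combined input scale times inv_output_scale.
trtllm_gen_global_scale = gate_up_combined_scale * inv_output_scale   # 0.4

# Routed experts in min-latency mode: their input scale is measured *after*
# the router-score multiplication, so the fc1 dequant scale is re-based onto
# the pre-score input scale (taken from the shared expert's gate_up_proj).
fc31_dequant = torch.tensor(0.10)
fc31_input_dequant = torch.tensor(0.08)       # post-score input scale
pre_score_input_scale = torch.tensor(0.05)    # shared expert's input scale
fc1_dequant = fc31_dequant / fc31_input_dequant * pre_score_input_scale  # 0.0625

print(inv_output_scale.item(), trtllm_gen_global_scale.item(), fc1_dequant.item())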

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 0 deletions
@@ -598,3 +598,6 @@ def load_weights(self, weights: List[Dict]):
             weights = weights[0]
 
         self.quant_method.load_weights(self, weights, self.weight_loading_mode)
+
+    def post_load_weights(self):
+        self.quant_method.post_load_weights(self)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py

Lines changed: 3 additions & 0 deletions
@@ -1385,3 +1385,6 @@ def load_weights(self, weights: List[Dict]):
             weights = weights[0]
 
         self.quant_method.load_weights(self, weights, self.weight_loading_mode)
+
+    def post_load_weights(self):
+        self.quant_method.post_load_weights(self)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 3 additions & 0 deletions
@@ -169,6 +169,9 @@ def load_weights(self, weights: List[Dict]):
 
         self.quant_method.load_weights(self, weights, self.weight_loading_mode)
 
+    def post_load_weights(self):
+        self.quant_method.post_load_weights(self)
+
     def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 3 additions & 0 deletions
@@ -1017,3 +1017,6 @@ def load_weights(self, weights: List[Dict]):
             weights = weights[0]
 
         self.quant_method.load_weights(self, weights, self.weight_loading_mode)
+
+    def post_load_weights(self):
+        self.quant_method.post_load_weights(self)

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 3 additions & 0 deletions
@@ -195,6 +195,9 @@ def create_weights(self):
     def load_weights(self, weights: List[Dict]):
         raise NotImplementedError
 
+    def post_load_weights(self):
+        pass
+
     @abstractmethod
     def forward_impl(
         self,

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 4 additions & 0 deletions
@@ -322,6 +322,9 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
             module.layer_load_balancer.set_initial_weight_assignments(
                 module.initial_global_assignments)
 
+    def post_load_weights(self, module: torch.nn.Module):
+        pass
+
     def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
         pass
 
@@ -726,6 +729,7 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                 weight, scale)
         super().load_weights(module, weights, weight_loading_mode)
 
+    def post_load_weights(self, module: torch.nn.Module):
         if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
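
Editorial note: quantization.py is where the pattern bottoms out. The base quant method gains a no-op post_load_weights(module), the FP8 block-scale method overrides it with the SM100 scale re-layout, and linear.py plus every fused MoE backend simply forward to their quant method. A compressed sketch of that layering, under generic, hypothetical class names rather than the real ones:

import torch
import torch.nn as nn


class QuantMethodBase:
    def load_weights(self, module: nn.Module, weights: dict) -> None:
        module.weight.data.copy_(weights["weight"])

    def post_load_weights(self, module: nn.Module) -> None:
        # Default: nothing to fix up after loading.
        pass


class BlockScaleQuantMethod(QuantMethodBase):
    def post_load_weights(self, module: nn.Module) -> None:
        # Placeholder for re-laying-out scale factors for the target arch.
        module.weight_scale.data = module.weight_scale.data.contiguous()


class QuantLinear(nn.Module):
    def __init__(self, quant_method: QuantMethodBase, n: int):
        super().__init__()
        self.quant_method = quant_method
        self.weight = nn.Parameter(torch.zeros(n, n), requires_grad=False)
        self.weight_scale = nn.Parameter(torch.ones(n // 16), requires_grad=False)

    def load_weights(self, weights: dict) -> None:
        self.quant_method.load_weights(self, weights)

    def post_load_weights(self) -> None:
        # Same one-liner the diff adds to Linear and the fused MoE backends.
        self.quant_method.post_load_weights(self)


layer = QuantLinear(BlockScaleQuantMethod(), n=32)
layer.load_weights({"weight": torch.randn(32, 32)})
layer.post_load_weights()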

tensorrt_llm/_torch/modules/linear.py

Lines changed: 6 additions & 0 deletions
@@ -241,6 +241,9 @@ def load_weights(self, module: Linear, weights: List[Dict],
         else:
             raise ValueError(f'unsupported weight mode: {weight_mode}')
 
+    def post_load_weights(self, module: Linear):
+        pass
+
     def load_weight_scales(self, weights: List[Dict], *args, **kwargs):
         """
         Load quantized weight scales from the checkpoint.
@@ -2001,3 +2004,6 @@ def load_weights(self, weights: List[Dict]):
 
         weight_mode = self.weights_loading_config.weight_mode
         self.quant_method.load_weights(self, weights, weight_mode)
+
+    def post_load_weights(self):
+        self.quant_method.post_load_weights(self)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 0 deletions
@@ -1048,6 +1048,10 @@ def init_meta_tensor(t: torch.Tensor):
             raise NotImplementedError(
                 f"No load support for load format: {load_format}")
 
+        for module in model.modules():
+            if hasattr(module, 'post_load_weights'):
+                module.post_load_weights()
+
         if isinstance(moe_load_balancer, MoeLoadBalancer):
             setattr(self, "moe_load_balancer", moe_load_balancer)
             moe_load_balancer.register_weight_slots_after_to_cuda()
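
Editorial note: this hunk is the actual fix for the dummy load format. After any load path completes, the engine walks model.modules() and invokes post_load_weights() on every module that defines it, so post-processing no longer depends on the checkpoint-loading path being taken. A minimal sketch of that dispatch, assuming a hypothetical HookedLinear module and a stand-in dummy-init path:

import torch
import torch.nn as nn


class HookedLinear(nn.Linear):
    """Hypothetical module that needs a fix-up once its weights exist."""

    def post_load_weights(self) -> None:
        self.post_loaded = True


def load_dummy_weights(model: nn.Module) -> None:
    # Stand-in for a dummy load format: weights are just random-initialized.
    with torch.no_grad():
        for p in model.parameters():
            p.normal_()


model = nn.Sequential(HookedLinear(8, 8), nn.ReLU(), HookedLinear(8, 8))

load_dummy_weights(model)  # no checkpoint involved

# Engine-side dispatch, mirroring the model_engine.py hunk: every module that
# defines post_load_weights gets called, regardless of load format.
for module in model.modules():
    if hasattr(module, "post_load_weights"):
        module.post_load_weights()

assert all(getattr(m, "post_loaded", False)
           for m in model if isinstance(m, HookedLinear))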
