Skip to content

Commit d0a4c05

Browse files
rosenrodtyumin066
authored and committed
fix weight_loading_mode for deepseek
Signed-off-by: Anthony Chang <[email protected]>
1 parent f12e6b1 commit d0a4c05

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from tensorrt_llm.llmapi.utils import enable_llm_debug
4646
from tensorrt_llm.mapping import Mapping
4747
from tensorrt_llm.models.modeling_utils import QuantConfig
48+
from tensorrt_llm.quantization.mode import QuantAlgo
4849
from tensorrt_llm.quantization.utils.fp8_utils import (
4950
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
5051

@@ -468,10 +469,13 @@ def __init__(self,
468469
layer_idx=layer_idx,
469470
# DS-R1 W4A8 is only supported through custom quantization script from
470471
# examples/quantization/quantize_mixed_precision_moe.py
471-
weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM
472-
if model_config.quant_config.quant_mode.
473-
is_int4_weight_only_per_group() else
474-
MoEWeightLoadingMode.VANILLA))
472+
weight_loading_mode=(
473+
MoEWeightLoadingMode.W4A8_CUSTOM
474+
if self._get_experts_quant_config(
475+
model_config,
476+
layer_idx).layer_quant_mode.is_int4_weight_only_per_group()
477+
else MoEWeightLoadingMode.VANILLA),
478+
)
475479

476480
self.mapping = model_config.mapping
477481

@@ -536,6 +540,13 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
536540

537541
return shared_tp_size, shared_output_scale
538542

543+
@staticmethod
544+
def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
545+
if getattr(model_config, "quant_config_dict", None) is None:
546+
return model_config.quant_config
547+
return model_config.quant_config_dict.get(
548+
f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
549+
539550
def compute_routed_output(self, hidden_states, hidden_states_fp4,
540551
all_rank_num_tokens, all_rank_max_num_tokens,
541552
do_finalize):
@@ -657,6 +668,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
657668
quant_config = self._get_decoder_layer_quant_config(
658669
model_config, layer_idx)
659670
self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
671+
assert (
672+
quant_config.quant_algo
673+
is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"
660674

661675
has_tp = mapping.has_tp()
662676
self.allreduce = AllReduce(mapping=model_config.mapping,

0 commit comments

Comments
 (0)