diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index cf404042497..4a500fa441d 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -45,6 +45,7 @@
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
 from tensorrt_llm.quantization.utils.fp8_utils import (
     resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
 
@@ -456,10 +457,13 @@ def __init__(self,
             layer_idx=layer_idx,
             # DS-R1 W4A8 is only supported through custom quantization script from
             # examples/quantization/quantize_mixed_precision_moe.py
-            weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM
-                                 if model_config.quant_config.quant_mode.
-                                 is_int4_weight_only_per_group() else
-                                 MoEWeightLoadingMode.VANILLA))
+            weight_loading_mode=(
+                MoEWeightLoadingMode.W4A8_CUSTOM
+                if self._get_experts_quant_config(
+                    model_config,
+                    layer_idx).layer_quant_mode.is_int4_weight_only_per_group()
+                else MoEWeightLoadingMode.VANILLA),
+        )
 
         self.mapping = model_config.mapping
 
@@ -524,6 +528,13 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
 
         return shared_tp_size, shared_output_scale
 
+    @staticmethod
+    def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
+        if getattr(model_config, "quant_config_dict", None) is None:
+            return model_config.quant_config
+        return model_config.quant_config_dict.get(
+            f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
+
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, all_rank_max_num_tokens,
                               do_finalize):
@@ -634,6 +645,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         quant_config = self._get_decoder_layer_quant_config(
             model_config, layer_idx)
         self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
+        assert (
+            quant_config.quant_algo
+            is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"
 
         has_tp = mapping.has_tp()
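
Note for reviewers: the snippet below is a minimal, self-contained sketch of the fallback behavior that `_get_experts_quant_config` introduces, not TensorRT-LLM code. A per-layer entry keyed `model.layers.{layer_idx}.mlp.experts` in `quant_config_dict` overrides the model-wide `quant_config`; a missing dict or missing key falls back to the global config. The `SimpleNamespace` stand-ins for `ModelConfig`/`QuantConfig` are illustrative only.

# Standalone sketch; SimpleNamespace stands in for ModelConfig/QuantConfig.
from types import SimpleNamespace


def get_experts_quant_config(model_config, layer_idx: int):
    # No per-layer dict at all -> use the model-wide quant config.
    if getattr(model_config, "quant_config_dict", None) is None:
        return model_config.quant_config
    # Per-layer lookup; a missing key also falls back to the global config.
    return model_config.quant_config_dict.get(
        f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)


global_cfg = SimpleNamespace(label="global")
experts_cfg = SimpleNamespace(label="w4a8-experts")
model_config = SimpleNamespace(
    quant_config=global_cfg,
    quant_config_dict={"model.layers.3.mlp.experts": experts_cfg},
)

assert get_experts_quant_config(model_config, 3) is experts_cfg  # override hit
assert get_experts_quant_config(model_config, 0) is global_cfg   # fallback

This is why the weight-loading-mode check moves off the global `model_config.quant_config.quant_mode`: with a mixed-precision checkpoint, whether the experts are INT4 weight-only per-group is a per-layer property, so W4A8_CUSTOM can now be selected for exactly the layers whose experts were quantized that way.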