22 changes: 18 additions & 4 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -45,6 +45,7 @@
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
 from tensorrt_llm.quantization.utils.fp8_utils import (
     resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

@@ -456,10 +457,13 @@ def __init__(self,
             layer_idx=layer_idx,
             # DS-R1 W4A8 is only supported through custom quantization script from
             # examples/quantization/quantize_mixed_precision_moe.py
-            weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM
-                                 if model_config.quant_config.quant_mode.
-                                 is_int4_weight_only_per_group() else
-                                 MoEWeightLoadingMode.VANILLA))
+            weight_loading_mode=(
+                MoEWeightLoadingMode.W4A8_CUSTOM
+                if self._get_experts_quant_config(
+                    model_config,
+                    layer_idx).layer_quant_mode.is_int4_weight_only_per_group()
+                else MoEWeightLoadingMode.VANILLA),
+        )

         self.mapping = model_config.mapping

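Reviewer note: a minimal sketch, not part of the diff, of what the new condition reads: the experts' per-layer quant mode rather than the checkpoint-wide quant_config.quant_mode the old expression consulted. The W4A8_AWQ value is an assumed example of a per-group int4 algo.

from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

# Hypothetical per-layer config for the experts, e.g. as produced for
# DS-R1 by examples/quantization/quantize_mixed_precision_moe.py.
experts_cfg = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ)

# The new condition keys off the experts' own layer quant mode, so a
# checkpoint that mixes W4A8 experts with differently quantized
# attention layers can still select the W4A8_CUSTOM loading path.
use_w4a8_custom = experts_cfg.layer_quant_mode.is_int4_weight_only_per_group()
print(use_w4a8_custom)  # expected: True for a per-group int4 config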
@@ -524,6 +528,13 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,

         return shared_tp_size, shared_output_scale

+    @staticmethod
+    def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
+        if getattr(model_config, "quant_config_dict", None) is None:
+            return model_config.quant_config
+        return model_config.quant_config_dict.get(
+            f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
+
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, all_rank_max_num_tokens,
                               do_finalize):
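Reviewer note: a hedged usage sketch for the new helper. The Deepseekv3MoE class name, the SimpleNamespace stand-in for ModelConfig, and the dict contents are assumptions for illustration; the fallback behavior follows from the helper above.

from types import SimpleNamespace

from tensorrt_llm._torch.models.modeling_deepseekv3 import Deepseekv3MoE
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

# Stand-in exposing only the two attributes the helper reads.
model_config = SimpleNamespace(
    quant_config=QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES),
    quant_config_dict={
        # Hypothetical: only layer 5's experts were quantized to W4A8.
        "model.layers.5.mlp.experts":
        QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ),
    },
)

# A layer present in the dict resolves to its per-layer config ...
assert Deepseekv3MoE._get_experts_quant_config(
    model_config, 5).quant_algo == QuantAlgo.W4A8_AWQ
# ... and any other layer falls back to the global config.
assert Deepseekv3MoE._get_experts_quant_config(
    model_config, 0).quant_algo == QuantAlgo.FP8_BLOCK_SCALES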
@@ -634,6 +645,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         quant_config = self._get_decoder_layer_quant_config(
             model_config, layer_idx)
         self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
+        assert (
+            quant_config.quant_algo
+            is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"

         has_tp = mapping.has_tp()

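Reviewer note: the new assertion enforces that per-layer resolution has already happened. MIXED_PRECISION appears to be the checkpoint-level marker meaning "consult quant_config_dict", so the config a concrete decoder layer ends up with must name a single algorithm. A standalone sketch of that invariant follows; check_resolved is a hypothetical helper, not code from this PR.

from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

def check_resolved(quant_config: QuantConfig) -> QuantConfig:
    # Mirrors the new guard: a resolved per-layer config must name one
    # concrete algorithm; MIXED_PRECISION would leave it ambiguous.
    assert (quant_config.quant_algo
            is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"
    return quant_config

check_resolved(QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES))  # passes
# check_resolved(QuantConfig(quant_algo=QuantAlgo.MIXED_PRECISION))  # raises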