@@ -45,6 +45,7 @@
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
 from tensorrt_llm.quantization.utils.fp8_utils import (
     resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
 
@@ -468,10 +469,13 @@ def __init__(self,
             layer_idx=layer_idx,
             # DS-R1 W4A8 is only supported through custom quantization script from
             # examples/quantization/quantize_mixed_precision_moe.py
-            weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM
-                                 if model_config.quant_config.quant_mode.
-                                 is_int4_weight_only_per_group() else
-                                 MoEWeightLoadingMode.VANILLA))
+            weight_loading_mode=(
+                MoEWeightLoadingMode.W4A8_CUSTOM
+                if self._get_experts_quant_config(
+                    model_config,
+                    layer_idx).layer_quant_mode.is_int4_weight_only_per_group()
+                else MoEWeightLoadingMode.VANILLA),
+            )
 
         self.mapping = model_config.mapping
 
@@ -536,6 +540,13 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
 
         return shared_tp_size, shared_output_scale
 
+    @staticmethod
+    def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
+        if getattr(model_config, "quant_config_dict", None) is None:
+            return model_config.quant_config
+        return model_config.quant_config_dict.get(
+            f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
+
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, all_rank_max_num_tokens,
                               do_finalize):
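For context: the new helper prefers a per-module entry in quant_config_dict (keyed by the module's full checkpoint name) and falls back to the global quant_config. A minimal runnable sketch of that lookup, using SimpleNamespace stand-ins for ModelConfig/QuantConfig (the objects and algo strings here are illustrative, not the real TensorRT-LLM types):

    from types import SimpleNamespace

    # Hypothetical stand-ins for ModelConfig/QuantConfig, illustration only.
    global_cfg = SimpleNamespace(quant_algo="MIXED_PRECISION")
    experts_cfg = SimpleNamespace(quant_algo="W4A8_AWQ")

    model_config = SimpleNamespace(
        quant_config=global_cfg,
        # Per-module overrides, keyed by the module's full name.
        quant_config_dict={"model.layers.3.mlp.experts": experts_cfg},
    )

    def get_experts_quant_config(model_config, layer_idx):
        # Same resolution order as the new static method in the diff above.
        if getattr(model_config, "quant_config_dict", None) is None:
            return model_config.quant_config
        return model_config.quant_config_dict.get(
            f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)

    print(get_experts_quant_config(model_config, 3).quant_algo)  # W4A8_AWQ
    print(get_experts_quant_config(model_config, 7).quant_algo)  # MIXED_PRECISION (fallback)

This is what lets the __init__ change above pick W4A8_CUSTOM weight loading only for layers whose experts are actually INT4-weight-only-per-group quantized, rather than keying off the global config.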
@@ -657,6 +668,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         quant_config = self._get_decoder_layer_quant_config(
             model_config, layer_idx)
         self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
+        assert (
+            quant_config.quant_algo
+            is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"
 
         has_tp = mapping.has_tp()
         self.allreduce = AllReduce(mapping=model_config.mapping,
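Why the new assert: the per-layer lookup can hand back the top-level config, whose quant_algo may be the umbrella QuantAlgo.MIXED_PRECISION; that value says nothing about how this specific layer is quantized, so the layer-level code refuses it. A hedged sketch of the invariant (string stand-ins for QuantAlgo members, illustrative only):

    def resolve_layer_quant_algo(layer_algo: str) -> str:
        # Mirrors the assert above: by the time a decoder layer sees its
        # quant config, the per-layer lookup must have replaced the umbrella
        # MIXED_PRECISION marker with a concrete algorithm (e.g. "FP8",
        # "W4A8_AWQ"); MIXED_PRECISION alone cannot drive kernel selection.
        assert layer_algo != "MIXED_PRECISION", "MIXED_PRECISION is ambiguous"
        return layer_algo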