From 305ef8c5868fb1f59bccb12a273afdf9f13fb276 Mon Sep 17 00:00:00 2001 From: beipingpan Date: Wed, 13 Aug 2025 16:17:36 +0800 Subject: [PATCH 1/5] Support W4A8 method of AngleSlim tool Signed-off-by: beipingpan --- tensorrt_llm/_torch/model_config.py | 81 ++++++++++++++++++- .../_torch/models/modeling_deepseekv3.py | 2 +- tensorrt_llm/_torch/models/modeling_utils.py | 9 ++- .../_torch/modules/fused_moe/quantization.py | 58 +++++++++++-- tensorrt_llm/llmapi/llm_utils.py | 35 +++++++- tensorrt_llm/models/modeling_utils.py | 6 +- tensorrt_llm/quantization/mode.py | 5 ++ 7 files changed, 183 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 907cc2a1d9e..5740141a95b 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -20,7 +20,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization.mode import QuantAlgo +from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) @@ -322,6 +322,47 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir, ] return quant_config, layer_quant_config + @staticmethod + def load_angelslim_quant_config(quant_config_file, checkpoint_dir, moe_backend): + quant_config = QuantConfig() + layer_quant_config = None + + with open(quant_config_file) as f: + quant_config_dict = json.load(f) + + json_quant_configs = quant_config_dict['quantization'] + + quant_config.quant_algo = QuantAlgo( + json_quant_configs.get('quant_algo', None).upper()) if json_quant_configs.get("quant_algo") else None + # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES + if quant_config.quant_algo == "fp8_pb_wo": + quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES') + + quant_config.kv_cache_quant_algo = QuantAlgo( + json_quant_configs.get("kv_cache_quant_algo").upper() + ) if json_quant_configs.get("kv_cache_quant_algo") else None + quant_config.group_size = json_quant_configs.get('group_size', None) + quant_config.exclude_modules = json_quant_configs.get( + 'exclude_modules', None) + quant_config.activation_scheme = ActivationScheme( + json_quant_configs.get('activation_scheme', None).upper() + ) if json_quant_configs.get("activation_scheme") else None + + json_exclude_quant_configs = json_quant_configs.get('exclude_quantization', None) + if json_exclude_quant_configs: + quant_config.exclude_quant_config = { + "quant_algo": QuantAlgo( + json_exclude_quant_configs.get('quant_algo', None).upper() + ) if json_exclude_quant_configs.get("quant_algo") else None, + "kv_cache_quant_algo": QuantAlgo( + json_exclude_quant_configs.get("kv_cache_quant_algo").upper() + ) if json_exclude_quant_configs.get("kv_cache_quant_algo") else None, + "activation_scheme": ActivationScheme( + json_exclude_quant_configs.get('activation_scheme', None).upper() + ) if json_exclude_quant_configs.get("activation_scheme") else None, + } + return quant_config, layer_quant_config + @staticmethod def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False): quant_algo = ModelConfig.override_quant_algo() @@ -366,6 +407,40 @@ def load_hf_quant_config(hf_quant_config, moe_backend): 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', 'embedding', 'unembedding' ] + elif hf_quant_config.get("quant_method") == "w4a8_awq": + quant_config.quant_algo = QuantAlgo.W4A8_AWQ + else: + raise 
NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.") + + # set kv_cache_quant_algo + quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \ + if hf_quant_config.get("kv_cache_quant_method") else None + # set activation_scheme + quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \ + if hf_quant_config.get("activation_scheme") else None + # set exclude_modules + if quant_config.exclude_modules: + if hf_quant_config.get("ignored_modules"): + quant_config.exclude_modules += hf_quant_config.get("ignored_modules") + else: + quant_config.exclude_modules = hf_quant_config.get("ignored_modules") + + # set exclude_quant_config + hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") + if hf_ignored_quantization_config: + quant_config.exclude_quant_config = { + "quant_algo": QuantAlgo( + hf_ignored_quantization_config.get("quant_method").upper() + ) if hf_ignored_quantization_config.get("quant_method") else None, + "kv_cache_quant_algo": QuantAlgo( + hf_ignored_quantization_config.get("kv_cache_quant_method").upper() + ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, + "activation_scheme": ActivationScheme( + hf_ignored_quantization_config.get("activation_scheme").upper() + ) if hf_ignored_quantization_config.get("activation_scheme") else None, + } + + logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}") return quant_config, layer_quant_config @@ -499,6 +574,10 @@ def cached_file(path_or_repo_id, file_name): 'hf_quant_config.json'): quant_config, layer_quant_config = cls.load_modelopt_quant_config( quant_config_file, checkpoint_dir, moe_backend) + elif quant_config_file := cached_file(checkpoint_dir, + 'angelslim_hf_quant_config.json'): + quant_config, layer_quant_config = cls.load_angelslim_quant_config( + quant_config_file, checkpoint_dir, moe_backend) # quantized ckpt in other formats elif hasattr(pretrained_config, "quantization_config"): hf_quant_config = pretrained_config.quantization_config diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 4d3d50abb33..c6ce314e4af 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -290,7 +290,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, if names[-1] == "kv_b_proj": # TODO: remove weight_dequant after enabling fp8_bmm dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization( - names[-1]) + names[-1]) and self.model_config.quant_config.exclude_quantization is None if dequant_kv_b_proj: kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant( name) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index 95cc9cac6be..ad56a91fe0b 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -480,7 +480,14 @@ def apply_quant_config_exclude_modules(self): kv_cache_quant_algo = None if quant_config: kv_cache_quant_algo = quant_config.kv_cache_quant_algo - new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo) + quant_algo = None + activation_scheme = None + exclude_quant_config = quant_config.exclude_quant_config + if exclude_quant_config: + quant_algo = exclude_quant_config.get("quant_algo", None) + activation_scheme = 
exclude_quant_config.get("activation_scheme", None) + new_config = QuantConfig( + quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo, activation_scheme=activation_scheme) if quant_config is not None: if quant_config.exclude_modules is not None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 1e56f90d5e9..53b654e54c8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -224,9 +224,10 @@ def load_expert_weights_to_dst( MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.W4A8_CUSTOM ]: - w1_weight = weights[f"{expert_id}.w1.weight"] - w3_weight = weights[f"{expert_id}.w3.weight"] - w2_weight = weights[f"{expert_id}.w2.weight"] + weight_name = "qweight" if f"{expert_id}.w1.qweight" in weights else "weight" + w1_weight = weights[f"{expert_id}.w1.{weight_name}"] + w3_weight = weights[f"{expert_id}.w3.{weight_name}"] + w2_weight = weights[f"{expert_id}.w2.{weight_name}"] if module.bias: w1_bias = weights[f"{expert_id}.w1.bias"] w3_bias = weights[f"{expert_id}.w3.bias"] @@ -1085,6 +1086,10 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM if w4a8_custom: weight_scale_name = "weight_scale_inv" + for expert_id in module.initial_local_expert_ids: + if f"{expert_id}.w3.weight_scale.int4" in weights: + weight_scale_name = "weight_scale.int4" + break else: weight_scale_name = "weight_scale" @@ -1107,13 +1112,31 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): all_w3_w1_input_scales_max = torch.max( torch.stack(all_w3_input_scales), torch.stack(all_w1_input_scales)).max() + all_w3_w1_scales_fp8_max = None + has_fp8_weight_scale = False if w4a8_custom: # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale module.fc31_act_scale.data.copy_( torch.ones_like(module.fc31_act_scale, device=self.device) * (1 / all_w3_w1_input_scales_max)) + + for expert_id in module.initial_local_expert_ids: + if f"{expert_id}.w1.weight_scale" in weights: + has_fp8_weight_scale = True + break + if has_fp8_weight_scale: + all_w3_w1_scales_fp8_max = [] + for expert_id in module.initial_local_expert_ids: + w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"], + device=self.device) + w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], + device=self.device) + all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8)) + all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape) + else: + all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha, device=self.device) module.fc31_alpha.data.copy_( - (torch.ones_like(module.fc31_alpha, device=self.device) * + (all_w3_w1_scales_fp8_max * all_w3_w1_input_scales_max).float()) else: # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored @@ -1192,6 +1215,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): device=self.device) for expert_id in module.initial_local_expert_ids ] + if w4a8_custom and has_fp8_weight_scale: + all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) all_w1_scales = [ load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"], module.tp_size, @@ -1200,9 +1225,15 @@ def load_quant_scales(self, 
module: torch.nn.Module, weights: Dict): device=self.device) for expert_id in module.initial_local_expert_ids ] - all_w3_w1_scales = torch.cat( - [torch.stack(all_w3_scales), - torch.stack(all_w1_scales)], dim=-2) + if w4a8_custom and has_fp8_weight_scale: + all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) + all_w3_w1_scales = torch.cat( + [all_w3_scales, + all_w1_scales], dim=-2) + else: + all_w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-2) if module.sm_version == 89: w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype) else: @@ -1234,14 +1265,23 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): all_w2_input_scales_max = torch.stack(all_w2_input_scales).to( module.dtype).max() + all_w2_scales_fp8 = None if w4a8_custom: # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale module.fc2_act_scale.data.copy_( torch.ones_like(module.fc2_act_scale, device=self.device) * (1 / all_w2_input_scales_max)) # In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha + if has_fp8_weight_scale: + all_w2_scales_fp8 = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape) + else: + all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha, device=self.device) module.fc2_alpha.data.copy_( - (torch.ones_like(module.fc2_alpha, device=self.device) * + (all_w2_scales_fp8 * all_w2_input_scales_max).float()) else: # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored @@ -1288,6 +1328,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): device=self.device) for expert_id in module.initial_local_expert_ids ] + if w4a8_custom and has_fp8_weight_scale: + all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2) if module.sm_version == 89: w2_scales = torch.stack(all_w2_scales).to(torch.float16).view( module.dtype) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index ecd0e5bfc16..ab31e342156 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -25,7 +25,7 @@ from ..logger import logger from ..mapping import Mapping from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM -from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig +from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig, ActivationScheme from ..module import Module from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, get_build_cache_config_from_env) @@ -423,10 +423,43 @@ def _update_from_hf_quant_config(self) -> bool: 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', 'embedding', 'unembedding' ] + elif hf_quant_config.get("quant_method") == "w4a8_awq": + quant_config.quant_algo = QuantAlgo.W4A8_AWQ else: raise NotImplementedError( f"Unsupported quantization_config: {hf_quant_config}.") + # set kv_cache_quant_algo + quant_config.kv_cache_quant_algo = QuantAlgo( + hf_quant_config.get("kv_cache_quant_method").upper() + ) if hf_quant_config.get("kv_cache_quant_method") else None + # set activation_scheme + quant_config.activation_scheme = ActivationScheme( + hf_quant_config.get("activation_scheme").upper() + ) if hf_quant_config.get("activation_scheme") else None + # 
set exclude_modules + if quant_config.exclude_modules: + if hf_quant_config.get("ignored_modules"): + quant_config.exclude_modules += hf_quant_config.get("ignored_modules") + else: + quant_config.exclude_modules = hf_quant_config.get("ignored_modules") + # set exclude_quant_config + hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") + if hf_ignored_quantization_config: + quant_config.exclude_quant_config = { + "quant_algo": QuantAlgo( + hf_ignored_quantization_config.get("quant_method").upper() + ) if hf_ignored_quantization_config.get("quant_method") else None, + "kv_cache_quant_algo": QuantAlgo( + hf_ignored_quantization_config.get("kv_cache_quant_method").upper() + ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, + "activation_scheme": ActivationScheme( + hf_ignored_quantization_config.get("activation_scheme").upper() + ) if hf_ignored_quantization_config.get("activation_scheme") else None, + } + logger.info( + f"Detected quantization_config: {quant_config}." + ) return True return False diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 03c1ee60ae5..e87d0c64fbb 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -45,7 +45,7 @@ WeightOnlyQuantLinear, WeightOnlyQuantRowLinear) from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST, - W8A8_SQ_PLUGIN_LIST, QuantAlgo) + W8A8_SQ_PLUGIN_LIST, QuantAlgo, ActivationScheme) from ..quantization.utils import fp4_utils from ..top_model_mixin import TopModelMixin from .convert_utils import weight_only_quantize_dict @@ -143,6 +143,8 @@ class QuantConfig: pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False. exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None. mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None. + exclude_quant_config (Dict, optional): The model of exclude_modules will use exclude_quant_config. + activation_scheme (tensorrt_llm.quantization.mode.ActivationScheme, optional): The input of activation quantize scheme. 
""" quant_algo: Optional[QuantAlgo] = None kv_cache_quant_algo: Optional[QuantAlgo] = None @@ -154,6 +156,8 @@ class QuantConfig: pre_quant_scale: bool = False exclude_modules: Optional[List[str]] = None mamba_ssm_cache_dtype: Optional[str] = None + exclude_quant_config: Optional[Dict] = None + activation_scheme: Optional[ActivationScheme] = None @cached_property def quant_mode(self) -> QuantModeWrapper: diff --git a/tensorrt_llm/quantization/mode.py b/tensorrt_llm/quantization/mode.py index 4615bc1376f..f6f31c4bd0d 100644 --- a/tensorrt_llm/quantization/mode.py +++ b/tensorrt_llm/quantization/mode.py @@ -473,3 +473,8 @@ class GroupwiseQuantAlgo: PRE_QUANT_SCALE = 4 W4A8_ALPHA = 8 INT8_WEIGHT = 16 + + +class ActivationScheme(StrEnum, metaclass=BaseEnumMeta): + STATIC = auto() + DYNAMIC = auto() From f975f38d04cd14a0d89a34dbaa0ca35f45a77deb Mon Sep 17 00:00:00 2001 From: beipingpan Date: Wed, 10 Sep 2025 17:07:29 +0800 Subject: [PATCH 2/5] Update quantization config Signed-off-by: beipingpan --- tensorrt_llm/_torch/model_config.py | 55 +++++++++++++++----- tensorrt_llm/_torch/models/modeling_utils.py | 16 ++++-- tensorrt_llm/llmapi/llm_utils.py | 8 +++ tensorrt_llm/models/modeling_utils.py | 4 +- 4 files changed, 62 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 5740141a95b..59d31c1485a 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -348,19 +348,28 @@ def load_angelslim_quant_config(quant_config_file, checkpoint_dir, moe_backend): json_quant_configs.get('activation_scheme', None).upper() ) if json_quant_configs.get("activation_scheme") else None - json_exclude_quant_configs = json_quant_configs.get('exclude_quantization', None) - if json_exclude_quant_configs: + json_exclude_quantization= json_quant_configs.get('exclude_quantization', None) + if json_exclude_quantization: quant_config.exclude_quant_config = { "quant_algo": QuantAlgo( - json_exclude_quant_configs.get('quant_algo', None).upper() - ) if json_exclude_quant_configs.get("quant_algo") else None, + json_exclude_quantization.get('quant_algo', None).upper() + ) if json_exclude_quantization.get("quant_algo") else None, "kv_cache_quant_algo": QuantAlgo( - json_exclude_quant_configs.get("kv_cache_quant_algo").upper() - ) if json_exclude_quant_configs.get("kv_cache_quant_algo") else None, + json_exclude_quantization.get("kv_cache_quant_algo").upper() + ) if json_exclude_quantization.get("kv_cache_quant_algo") else None, "activation_scheme": ActivationScheme( - json_exclude_quant_configs.get('activation_scheme', None).upper() - ) if json_exclude_quant_configs.get("activation_scheme") else None, + json_exclude_quantization.get('activation_scheme', None).upper() + ) if json_exclude_quantization.get("activation_scheme") else None, + "group_size": json_exclude_quantization.get('group_size', None), } + if quant_config.exclude_quantization["quant_algo"] in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]: + if quant_config.exclude_quantization["group_size"] is None: + quant_config.exclude_quantization["group_size"] = 128 + + if quant_config.quant_algo in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]: + if quant_config.group_size is None: + quant_config.group_size = 128 + return quant_config, layer_quant_config @staticmethod @@ -407,8 +416,11 @@ def load_hf_quant_config(hf_quant_config, moe_backend): 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', 'embedding', 'unembedding' ] + elif 
hf_quant_config.get("quant_method") == "fp8": + quant_config.quant_algo = QuantAlgo.FP8 elif hf_quant_config.get("quant_method") == "w4a8_awq": quant_config.quant_algo = QuantAlgo.W4A8_AWQ + quant_config.group_size = hf_quant_config.get("weight_group_size", 128) else: raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.") @@ -420,25 +432,40 @@ def load_hf_quant_config(hf_quant_config, moe_backend): if hf_quant_config.get("activation_scheme") else None # set exclude_modules if quant_config.exclude_modules: - if hf_quant_config.get("ignored_modules"): - quant_config.exclude_modules += hf_quant_config.get("ignored_modules") + if hf_quant_config.get("ignored_layers"): + quant_config.exclude_modules += hf_quant_config.get("ignored_layers") else: - quant_config.exclude_modules = hf_quant_config.get("ignored_modules") + quant_config.exclude_modules = hf_quant_config.get("ignored_layers") # set exclude_quant_config hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") if hf_ignored_quantization_config: quant_config.exclude_quant_config = { - "quant_algo": QuantAlgo( - hf_ignored_quantization_config.get("quant_method").upper() - ) if hf_ignored_quantization_config.get("quant_method") else None, "kv_cache_quant_algo": QuantAlgo( hf_ignored_quantization_config.get("kv_cache_quant_method").upper() ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, "activation_scheme": ActivationScheme( hf_ignored_quantization_config.get("activation_scheme").upper() ) if hf_ignored_quantization_config.get("activation_scheme") else None, + "group_size": 128, } + if hf_ignored_quantization_config.get( + "quant_method") == "fp8" and hf_ignored_quantization_config.get("weight_block_size", []): + quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES + block_size = hf_ignored_quantization_config.get("weight_block_size", []) + assert tuple(block_size) == ( + 128, + 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" + quant_config.exclude_quantization["group_size"] = block_size[0] + elif hf_ignored_quantization_config.get("quant_method") == "fp8": + quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8 + elif hf_ignored_quantization_config.get("quant_method") == "w4a8_awq": + quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ + quant_config.exclude_quantization["group_size"] = hf_ignored_quantization_config.get( + "weight_group_size", 128) + else: + raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: " + f"{hf_ignored_quantization_config}.") logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}") diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index ad56a91fe0b..cd0a54df4e7 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -482,12 +482,18 @@ def apply_quant_config_exclude_modules(self): kv_cache_quant_algo = quant_config.kv_cache_quant_algo quant_algo = None activation_scheme = None - exclude_quant_config = quant_config.exclude_quant_config - if exclude_quant_config: - quant_algo = exclude_quant_config.get("quant_algo", None) - activation_scheme = exclude_quant_config.get("activation_scheme", None) + group_size = 128 + exclude_quantization = quant_config.exclude_quantization + if exclude_quantization: + quant_algo = exclude_quantization.get("quant_algo", None) + activation_scheme = 
exclude_quantization.get("activation_scheme", None) + group_size = exclude_quantization.get("group_size", 128) new_config = QuantConfig( - quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo, activation_scheme=activation_scheme) + quant_algo=quant_algo, + kv_cache_quant_algo=kv_cache_quant_algo, + activation_scheme=activation_scheme, + group_size=group_size, + ) if quant_config is not None: if quant_config.exclude_modules is not None: diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index ab31e342156..aa7d0a3c1d9 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -414,6 +414,11 @@ def _update_from_hf_quant_config(self) -> bool: "weight_block_size"): quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES quant_config.exclude_modules = ["*eh_proj"] + block_size = hf_quant_config.get("weight_block_size", []) + assert tuple(block_size) == ( + 128, + 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" + quant_config.group_size = block_size[0] elif hf_quant_config.get("quant_method") == "mxfp4": from .._torch.model_config import ModelConfig quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( @@ -423,8 +428,11 @@ def _update_from_hf_quant_config(self) -> bool: 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', 'embedding', 'unembedding' ] + elif hf_quant_config.get("quant_method") == "fp8": + quant_config.quant_algo = QuantAlgo.FP8 elif hf_quant_config.get("quant_method") == "w4a8_awq": quant_config.quant_algo = QuantAlgo.W4A8_AWQ + quant_config.group_size = hf_quant_config.get("weight_group_size", 128) else: raise NotImplementedError( f"Unsupported quantization_config: {hf_quant_config}.") diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index e87d0c64fbb..fed58d3b617 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -143,7 +143,7 @@ class QuantConfig: pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False. exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None. mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None. - exclude_quant_config (Dict, optional): The model of exclude_modules will use exclude_quant_config. + exclude_quantization (Dict, optional): The model of exclude_modules will use exclude_quantization. activation_scheme (tensorrt_llm.quantization.mode.ActivationScheme, optional): The input of activation quantize scheme. 
""" quant_algo: Optional[QuantAlgo] = None @@ -156,7 +156,7 @@ class QuantConfig: pre_quant_scale: bool = False exclude_modules: Optional[List[str]] = None mamba_ssm_cache_dtype: Optional[str] = None - exclude_quant_config: Optional[Dict] = None + exclude_quantization: Optional[Dict] = None activation_scheme: Optional[ActivationScheme] = None @cached_property From 2f18e383dca112906ff915a3cacc7ef6001854eb Mon Sep 17 00:00:00 2001 From: beipingpan Date: Mon, 27 Oct 2025 20:16:39 +0800 Subject: [PATCH 3/5] Reformate code Signed-off-by: beipingpan --- tensorrt_llm/_torch/model_config.py | 106 ++++++++++++------ tensorrt_llm/_torch/models/modeling_utils.py | 3 +- .../_torch/modules/fused_moe/quantization.py | 50 +++++---- tensorrt_llm/llmapi/llm_utils.py | 50 +++++---- tensorrt_llm/models/modeling_utils.py | 3 +- 5 files changed, 134 insertions(+), 78 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 59d31c1485a..e2cd0b8fc95 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -20,7 +20,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme +from tensorrt_llm.quantization.mode import ActivationScheme, QuantAlgo TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) @@ -333,7 +333,9 @@ def load_angelslim_quant_config(quant_config_file, checkpoint_dir, moe_backend): json_quant_configs = quant_config_dict['quantization'] quant_config.quant_algo = QuantAlgo( - json_quant_configs.get('quant_algo', None).upper()) if json_quant_configs.get("quant_algo") else None + json_quant_configs.get( + 'quant_algo', + None).upper()) if json_quant_configs.get("quant_algo") else None # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES if quant_config.quant_algo == "fp8_pb_wo": quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES') @@ -348,25 +350,36 @@ def load_angelslim_quant_config(quant_config_file, checkpoint_dir, moe_backend): json_quant_configs.get('activation_scheme', None).upper() ) if json_quant_configs.get("activation_scheme") else None - json_exclude_quantization= json_quant_configs.get('exclude_quantization', None) + json_exclude_quantization = json_quant_configs.get( + 'exclude_quantization', None) if json_exclude_quantization: quant_config.exclude_quant_config = { - "quant_algo": QuantAlgo( - json_exclude_quantization.get('quant_algo', None).upper() - ) if json_exclude_quantization.get("quant_algo") else None, - "kv_cache_quant_algo": QuantAlgo( - json_exclude_quantization.get("kv_cache_quant_algo").upper() - ) if json_exclude_quantization.get("kv_cache_quant_algo") else None, - "activation_scheme": ActivationScheme( - json_exclude_quantization.get('activation_scheme', None).upper() - ) if json_exclude_quantization.get("activation_scheme") else None, - "group_size": json_exclude_quantization.get('group_size', None), + "quant_algo": + QuantAlgo( + json_exclude_quantization.get('quant_algo', None).upper()) + if json_exclude_quantization.get("quant_algo") else None, + "kv_cache_quant_algo": + QuantAlgo( + json_exclude_quantization.get( + "kv_cache_quant_algo").upper()) if + json_exclude_quantization.get("kv_cache_quant_algo") else None, + "activation_scheme": + ActivationScheme( + json_exclude_quantization.get('activation_scheme', + None).upper()) + if json_exclude_quantization.get("activation_scheme") else None, + "group_size": + 
json_exclude_quantization.get('group_size', None), } - if quant_config.exclude_quantization["quant_algo"] in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]: + if quant_config.exclude_quantization["quant_algo"] in [ + QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ + ]: if quant_config.exclude_quantization["group_size"] is None: quant_config.exclude_quantization["group_size"] = 128 - if quant_config.quant_algo in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]: + if quant_config.quant_algo in [ + QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ + ]: if quant_config.group_size is None: quant_config.group_size = 128 @@ -420,9 +433,11 @@ def load_hf_quant_config(hf_quant_config, moe_backend): quant_config.quant_algo = QuantAlgo.FP8 elif hf_quant_config.get("quant_method") == "w4a8_awq": quant_config.quant_algo = QuantAlgo.W4A8_AWQ - quant_config.group_size = hf_quant_config.get("weight_group_size", 128) + quant_config.group_size = hf_quant_config.get( + "weight_group_size", 128) else: - raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.") + raise NotImplementedError( + f"Unsupported quantization_config: {hf_quant_config}.") # set kv_cache_quant_algo quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \ @@ -433,41 +448,60 @@ def load_hf_quant_config(hf_quant_config, moe_backend): # set exclude_modules if quant_config.exclude_modules: if hf_quant_config.get("ignored_layers"): - quant_config.exclude_modules += hf_quant_config.get("ignored_layers") + quant_config.exclude_modules += hf_quant_config.get( + "ignored_layers") else: quant_config.exclude_modules = hf_quant_config.get("ignored_layers") # set exclude_quant_config - hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") + hf_ignored_quantization_config = hf_quant_config.get( + "ignored_quantization_config") if hf_ignored_quantization_config: quant_config.exclude_quant_config = { - "kv_cache_quant_algo": QuantAlgo( - hf_ignored_quantization_config.get("kv_cache_quant_method").upper() - ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, - "activation_scheme": ActivationScheme( - hf_ignored_quantization_config.get("activation_scheme").upper() - ) if hf_ignored_quantization_config.get("activation_scheme") else None, - "group_size": 128, + "kv_cache_quant_algo": + QuantAlgo( + hf_ignored_quantization_config.get( + "kv_cache_quant_method").upper()) + if hf_ignored_quantization_config.get("kv_cache_quant_method") + else None, + "activation_scheme": + ActivationScheme( + hf_ignored_quantization_config.get( + "activation_scheme").upper()) + if hf_ignored_quantization_config.get("activation_scheme") else + None, + "group_size": + 128, } if hf_ignored_quantization_config.get( - "quant_method") == "fp8" and hf_ignored_quantization_config.get("weight_block_size", []): - quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES - block_size = hf_ignored_quantization_config.get("weight_block_size", []) + "quant_method" + ) == "fp8" and hf_ignored_quantization_config.get( + "weight_block_size", []): + quant_config.exclude_quantization[ + "quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES + block_size = hf_ignored_quantization_config.get( + "weight_block_size", []) assert tuple(block_size) == ( 128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" quant_config.exclude_quantization["group_size"] = block_size[0] elif hf_ignored_quantization_config.get("quant_method") == "fp8": 
quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8 - elif hf_ignored_quantization_config.get("quant_method") == "w4a8_awq": - quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ - quant_config.exclude_quantization["group_size"] = hf_ignored_quantization_config.get( - "weight_group_size", 128) + elif hf_ignored_quantization_config.get( + "quant_method") == "w4a8_awq": + quant_config.exclude_quantization[ + "quant_algo"] = QuantAlgo.W4A8_AWQ + quant_config.exclude_quantization[ + "group_size"] = hf_ignored_quantization_config.get( + "weight_group_size", 128) else: - raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: " - f"{hf_ignored_quantization_config}.") + raise NotImplementedError( + f"Unsupported quantization_config.ignored_quantization_config: " + f"{hf_ignored_quantization_config}.") - logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}") + logger.info( + f"Load quantization config from pretrained config, quant_config: {quant_config}" + ) return quant_config, layer_quant_config diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index cd0a54df4e7..305fe32276e 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -486,7 +486,8 @@ def apply_quant_config_exclude_modules(self): exclude_quantization = quant_config.exclude_quantization if exclude_quantization: quant_algo = exclude_quantization.get("quant_algo", None) - activation_scheme = exclude_quantization.get("activation_scheme", None) + activation_scheme = exclude_quantization.get( + "activation_scheme", None) group_size = exclude_quantization.get("group_size", 128) new_config = QuantConfig( quant_algo=quant_algo, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 53b654e54c8..4b7a7578385 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1127,17 +1127,21 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): if has_fp8_weight_scale: all_w3_w1_scales_fp8_max = [] for expert_id in module.initial_local_expert_ids: - w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"], - device=self.device) - w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], - device=self.device) - all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8)) - all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape) + w1_weight_scale_fp8 = load_weight_shard( + weights[f"{expert_id}.w1.weight_scale"], + device=self.device) + w3_weight_scale_fp8 = load_weight_shard( + weights[f"{expert_id}.w3.weight_scale"], + device=self.device) + all_w3_w1_scales_fp8_max.append( + torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8)) + all_w3_w1_scales_fp8_max = torch.stack( + all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape) else: - all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha, device=self.device) + all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha, + device=self.device) module.fc31_alpha.data.copy_( - (all_w3_w1_scales_fp8_max * - all_w3_w1_input_scales_max).float()) + (all_w3_w1_scales_fp8_max * all_w3_w1_input_scales_max).float()) else: # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are 
separately stored all_w3_pre_quant_scales = [ @@ -1216,7 +1220,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): for expert_id in module.initial_local_expert_ids ] if w4a8_custom and has_fp8_weight_scale: - all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) + all_w3_scales = torch.stack( + all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) all_w1_scales = [ load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"], module.tp_size, @@ -1226,14 +1231,14 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): for expert_id in module.initial_local_expert_ids ] if w4a8_custom and has_fp8_weight_scale: - all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) - all_w3_w1_scales = torch.cat( - [all_w3_scales, - all_w1_scales], dim=-2) + all_w1_scales = torch.stack( + all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) + all_w3_w1_scales = torch.cat([all_w3_scales, all_w1_scales], dim=-2) else: all_w3_w1_scales = torch.cat( [torch.stack(all_w3_scales), - torch.stack(all_w1_scales)], dim=-2) + torch.stack(all_w1_scales)], + dim=-2) if module.sm_version == 89: w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype) else: @@ -1274,15 +1279,17 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): # In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha if has_fp8_weight_scale: all_w2_scales_fp8 = [ - load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], device=self.device) + load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], + device=self.device) for expert_id in module.initial_local_expert_ids ] - all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape) + all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape( + module.fc2_alpha.shape) else: - all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha, device=self.device) + all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha, + device=self.device) module.fc2_alpha.data.copy_( - (all_w2_scales_fp8 * - all_w2_input_scales_max).float()) + (all_w2_scales_fp8 * all_w2_input_scales_max).float()) else: # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored all_w2_pre_quant_scales = [ @@ -1329,7 +1336,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): for expert_id in module.initial_local_expert_ids ] if w4a8_custom and has_fp8_weight_scale: - all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2) + all_w2_scales = torch.stack( + all_w2_scales) / all_w2_scales_fp8.unsqueeze(2) if module.sm_version == 89: w2_scales = torch.stack(all_w2_scales).to(torch.float16).view( module.dtype) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index aa7d0a3c1d9..2664fd9df55 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -25,7 +25,8 @@ from ..logger import logger from ..mapping import Mapping from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM -from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig, ActivationScheme +from ..models.modeling_utils import (ActivationScheme, PretrainedConfig, + QuantAlgo, QuantConfig) from ..module import Module from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, get_build_cache_config_from_env) @@ -416,8 +417,8 @@ def _update_from_hf_quant_config(self) -> bool: quant_config.exclude_modules = ["*eh_proj"] 
block_size = hf_quant_config.get("weight_block_size", []) assert tuple(block_size) == ( - 128, - 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" + 128, 128 + ), "FP8_BLOCK_SCALES only supports block_size=(128,128)" quant_config.group_size = block_size[0] elif hf_quant_config.get("quant_method") == "mxfp4": from .._torch.model_config import ModelConfig @@ -432,7 +433,8 @@ def _update_from_hf_quant_config(self) -> bool: quant_config.quant_algo = QuantAlgo.FP8 elif hf_quant_config.get("quant_method") == "w4a8_awq": quant_config.quant_algo = QuantAlgo.W4A8_AWQ - quant_config.group_size = hf_quant_config.get("weight_group_size", 128) + quant_config.group_size = hf_quant_config.get( + "weight_group_size", 128) else: raise NotImplementedError( f"Unsupported quantization_config: {hf_quant_config}.") @@ -448,26 +450,36 @@ def _update_from_hf_quant_config(self) -> bool: # set exclude_modules if quant_config.exclude_modules: if hf_quant_config.get("ignored_modules"): - quant_config.exclude_modules += hf_quant_config.get("ignored_modules") + quant_config.exclude_modules += hf_quant_config.get( + "ignored_modules") else: - quant_config.exclude_modules = hf_quant_config.get("ignored_modules") + quant_config.exclude_modules = hf_quant_config.get( + "ignored_modules") # set exclude_quant_config - hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") + hf_ignored_quantization_config = hf_quant_config.get( + "ignored_quantization_config") if hf_ignored_quantization_config: quant_config.exclude_quant_config = { - "quant_algo": QuantAlgo( - hf_ignored_quantization_config.get("quant_method").upper() - ) if hf_ignored_quantization_config.get("quant_method") else None, - "kv_cache_quant_algo": QuantAlgo( - hf_ignored_quantization_config.get("kv_cache_quant_method").upper() - ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, - "activation_scheme": ActivationScheme( - hf_ignored_quantization_config.get("activation_scheme").upper() - ) if hf_ignored_quantization_config.get("activation_scheme") else None, + "quant_algo": + QuantAlgo( + hf_ignored_quantization_config.get( + "quant_method").upper()) + if hf_ignored_quantization_config.get("quant_method") + else None, + "kv_cache_quant_algo": + QuantAlgo( + hf_ignored_quantization_config.get( + "kv_cache_quant_method").upper()) + if hf_ignored_quantization_config.get( + "kv_cache_quant_method") else None, + "activation_scheme": + ActivationScheme( + hf_ignored_quantization_config.get( + "activation_scheme").upper()) if + hf_ignored_quantization_config.get("activation_scheme") + else None, } - logger.info( - f"Detected quantization_config: {quant_config}." 
- ) + logger.info(f"Detected quantization_config: {quant_config}.") return True return False diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index fed58d3b617..666eb0eccba 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -45,7 +45,8 @@ WeightOnlyQuantLinear, WeightOnlyQuantRowLinear) from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST, - W8A8_SQ_PLUGIN_LIST, QuantAlgo, ActivationScheme) + W8A8_SQ_PLUGIN_LIST, ActivationScheme, + QuantAlgo) from ..quantization.utils import fp4_utils from ..top_model_mixin import TopModelMixin from .convert_utils import weight_only_quantize_dict From e14abe6dcec22f464305cc8698de330d3ab3c18a Mon Sep 17 00:00:00 2001 From: beipingpan Date: Tue, 28 Oct 2025 17:37:55 +0800 Subject: [PATCH 4/5] Reformat code Signed-off-by: beipingpan --- tensorrt_llm/_torch/model_config.py | 5 +++-- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index e2cd0b8fc95..dda2f34e56e 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -323,7 +323,8 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir, return quant_config, layer_quant_config @staticmethod - def load_angelslim_quant_config(quant_config_file, checkpoint_dir, moe_backend): + def load_angelslim_quant_config(quant_config_file, checkpoint_dir, + moe_backend): quant_config = QuantConfig() layer_quant_config = None @@ -636,7 +637,7 @@ def cached_file(path_or_repo_id, file_name): quant_config, layer_quant_config = cls.load_modelopt_quant_config( quant_config_file, checkpoint_dir, moe_backend) elif quant_config_file := cached_file(checkpoint_dir, - 'angelslim_hf_quant_config.json'): + 'angelslim_hf_quant_config.json'): quant_config, layer_quant_config = cls.load_angelslim_quant_config( quant_config_file, checkpoint_dir, moe_backend) # quantized ckpt in other formats diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index c6ce314e4af..ab15f71df13 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -290,7 +290,8 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, if names[-1] == "kv_b_proj": # TODO: remove weight_dequant after enabling fp8_bmm dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization( - names[-1]) and self.model_config.quant_config.exclude_quantization is None + names[-1] + ) and self.model_config.quant_config.exclude_quantization is None if dequant_kv_b_proj: kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant( name) From 431925b951cba33fea80e277aa99b8bb8a5479d3 Mon Sep 17 00:00:00 2001 From: beipingpan Date: Thu, 30 Oct 2025 15:38:59 +0800 Subject: [PATCH 5/5] Fix when quant_config is None Signed-off-by: beipingpan --- tensorrt_llm/_torch/models/modeling_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index 305fe32276e..d66c695365f 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -478,17 +478,17 @@ def apply_quant_config_exclude_modules(self): """ quant_config = self.model_config.quant_config kv_cache_quant_algo = None - if 
quant_config: - kv_cache_quant_algo = quant_config.kv_cache_quant_algo quant_algo = None activation_scheme = None group_size = 128 - exclude_quantization = quant_config.exclude_quantization - if exclude_quantization: - quant_algo = exclude_quantization.get("quant_algo", None) - activation_scheme = exclude_quantization.get( - "activation_scheme", None) - group_size = exclude_quantization.get("group_size", 128) + if quant_config: + kv_cache_quant_algo = quant_config.kv_cache_quant_algo + exclude_quantization = quant_config.exclude_quantization + if exclude_quantization: + quant_algo = exclude_quantization.get("quant_algo", None) + activation_scheme = exclude_quantization.get( + "activation_scheme", None) + group_size = exclude_quantization.get("group_size", 128) new_config = QuantConfig( quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo,
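
Note for reviewers: the new load_angelslim_quant_config() path looks for an angelslim_hf_quant_config.json next to the checkpoint and reads its top-level "quantization" object. The Python sketch below is illustrative only: the key names mirror what the loader added in this series reads, while the concrete values (algorithms, group size, module patterns) are hypothetical and not taken from a real AngelSlim export.

    # Illustrative sketch only: keys follow load_angelslim_quant_config(); values are made up.
    import json

    angelslim_cfg = {
        "quantization": {
            "quant_algo": "w4a8_awq",            # parsed as QuantAlgo("W4A8_AWQ") after .upper()
            "kv_cache_quant_algo": "fp8",        # optional; parsed as QuantAlgo("FP8")
            "group_size": 128,                   # loader falls back to 128 for W4A8_AWQ / FP8_BLOCK_SCALES
            "activation_scheme": "static",       # parsed as ActivationScheme("STATIC")
            "exclude_modules": ["*kv_b_proj*"],  # hypothetical pattern; passed through unchanged
            "exclude_quantization": {            # quantization applied to the excluded modules
                "quant_algo": "fp8",
                "activation_scheme": "dynamic",
                "group_size": 128,
            },
        },
    }

    with open("angelslim_hf_quant_config.json", "w") as f:
        json.dump(angelslim_cfg, f, indent=2)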