diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index f77d309805e..4148da442b3 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -18,7 +18,7 @@
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
+from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme
 
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
 
@@ -286,6 +286,56 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
         ]
         return quant_config, layer_quant_config
 
+    @staticmethod
+    def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend):
+        quant_config = QuantConfig()
+        layer_quant_config = None
+
+        with open(quant_config_file) as f:
+            quant_config_dict = json.load(f)
+
+        json_quant_configs = quant_config_dict['quantization']
+
+        json_quant_algo = json_quant_configs.get('quant_algo', None)
+        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
+        if json_quant_algo == "fp8_pb_wo":
+            json_quant_algo = "FP8_BLOCK_SCALES"
+        quant_config.quant_algo = QuantAlgo(json_quant_algo.upper()) if json_quant_algo else None
+
+        quant_config.kv_cache_quant_algo = QuantAlgo(
+            json_quant_configs.get("kv_cache_quant_algo").upper()
+        ) if json_quant_configs.get("kv_cache_quant_algo") else None
+        quant_config.group_size = json_quant_configs.get('group_size', None)
+        quant_config.exclude_modules = json_quant_configs.get(
+            'exclude_modules', None)
+        quant_config.activation_scheme = ActivationScheme(
+            json_quant_configs.get('activation_scheme', None).upper()
+        ) if json_quant_configs.get("activation_scheme") else None
+
+        json_exclude_quantization = json_quant_configs.get('exclude_quantization', None)
+        if json_exclude_quantization:
+            quant_config.exclude_quantization = {
+                "quant_algo": QuantAlgo(
+                    json_exclude_quantization.get('quant_algo', None).upper()
+                ) if json_exclude_quantization.get("quant_algo") else None,
+                "kv_cache_quant_algo": QuantAlgo(
+                    json_exclude_quantization.get("kv_cache_quant_algo").upper()
+                ) if json_exclude_quantization.get("kv_cache_quant_algo") else None,
+                "activation_scheme": ActivationScheme(
+                    json_exclude_quantization.get('activation_scheme', None).upper()
+                ) if json_exclude_quantization.get("activation_scheme") else None,
+                "group_size": json_exclude_quantization.get('group_size', None),
+            }
+            if quant_config.exclude_quantization["quant_algo"] in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]:
+                if quant_config.exclude_quantization["group_size"] is None:
+                    quant_config.exclude_quantization["group_size"] = 128
+
+        if quant_config.quant_algo in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]:
+            if quant_config.group_size is None:
+                quant_config.group_size = 128
+
+        return quant_config, layer_quant_config
+
     @staticmethod
     def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
         quant_algo = ModelConfig.override_quant_algo()
@@ -330,6 +380,58 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
                 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                 'embedding', 'unembedding'
             ]
+        elif hf_quant_config.get("quant_method") == "fp8":
+            quant_config.quant_algo = QuantAlgo.FP8
+        elif hf_quant_config.get("quant_method") == "w4a8_awq":
+            quant_config.quant_algo = QuantAlgo.W4A8_AWQ
+            quant_config.group_size = hf_quant_config.get("weight_group_size", 128)
+        else:
+            raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.")
+
+        # set kv_cache_quant_algo
+        quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \
+            if hf_quant_config.get("kv_cache_quant_method") else None
+        # set activation_scheme
+        quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \
+            if hf_quant_config.get("activation_scheme") else None
+        # set exclude_modules
+        if quant_config.exclude_modules:
+            if hf_quant_config.get("ignored_layers"):
+                quant_config.exclude_modules += hf_quant_config.get("ignored_layers")
+        else:
+            quant_config.exclude_modules = hf_quant_config.get("ignored_layers")
+
+        # set exclude_quantization
+        hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
+        if hf_ignored_quantization_config:
+            quant_config.exclude_quantization = {
+                "kv_cache_quant_algo": QuantAlgo(
+                    hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
+                ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
+                "activation_scheme": ActivationScheme(
+                    hf_ignored_quantization_config.get("activation_scheme").upper()
+                ) if hf_ignored_quantization_config.get("activation_scheme") else None,
+                "group_size": 128,
+            }
+            if hf_ignored_quantization_config.get(
+                    "quant_method") == "fp8" and hf_ignored_quantization_config.get("weight_block_size", []):
+                quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
+                block_size = hf_ignored_quantization_config.get("weight_block_size", [])
+                assert tuple(block_size) == (
+                    128,
+                    128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
+                quant_config.exclude_quantization["group_size"] = block_size[0]
+            elif hf_ignored_quantization_config.get("quant_method") == "fp8":
+                quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
+            elif hf_ignored_quantization_config.get("quant_method") == "w4a8_awq":
+                quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ
+                quant_config.exclude_quantization["group_size"] = hf_ignored_quantization_config.get(
+                    "weight_group_size", 128)
+            else:
+                raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: "
+                                          f"{hf_ignored_quantization_config}.")
+
+        logger.info(f"Loaded quantization config from pretrained config, quant_config: {quant_config}")
 
         return quant_config, layer_quant_config
 
@@ -412,6 +514,9 @@ def from_pretrained(cls,
         if (quant_config_file := model_dir / 'hf_quant_config.json').exists():
             quant_config, layer_quant_config = cls.load_modelopt_quant_config(
                 quant_config_file, model_dir, moe_backend)
+        elif (quant_config_file := model_dir / 'angelslim_hf_quant_config.json').exists():
+            quant_config, layer_quant_config = cls.load_angelslim_quant_config(
+                quant_config_file, model_dir, moe_backend)
         # quantized ckpt in other formats
         elif hasattr(pretrained_config, "quantization_config"):
             hf_quant_config = pretrained_config.quantization_config
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 0ca4d28085b..260d943c8cb 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1383,7 +1383,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
             if names[-1] == "kv_b_proj":
                 # TODO: remove weight_dequant after enabling fp8_bmm
                 dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization(
-                    names[-1])
+                    names[-1]) and self.model_config.quant_config.exclude_quantization is None
                 if dequant_kv_b_proj:
                     kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant(
                         name)
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index 284a31c26a6..45cc48a6c02 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -477,7 +477,20 @@ def apply_quant_config_exclude_modules(self):
         kv_cache_quant_algo = None
         if quant_config:
             kv_cache_quant_algo = quant_config.kv_cache_quant_algo
-        new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
+        quant_algo = None
+        activation_scheme = None
+        group_size = 128
+        exclude_quantization = quant_config.exclude_quantization if quant_config else None
+        if exclude_quantization:
+            quant_algo = exclude_quantization.get("quant_algo", None)
+            activation_scheme = exclude_quantization.get("activation_scheme", None)
+            group_size = exclude_quantization.get("group_size", 128)
+        new_config = QuantConfig(
+            quant_algo=quant_algo,
+            kv_cache_quant_algo=kv_cache_quant_algo,
+            activation_scheme=activation_scheme,
+            group_size=group_size,
+        )
 
         if quant_config is not None:
             if quant_config.exclude_modules is not None:
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 0c46afcfb83..2b0917ece4d 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -224,9 +224,10 @@ def load_expert_weights_to_dst(
                 MoEWeightLoadingMode.VANILLA,
                 MoEWeightLoadingMode.W4A8_CUSTOM
         ]:
-            w1_weight = weights[f"{expert_id}.w1.weight"]
-            w3_weight = weights[f"{expert_id}.w3.weight"]
-            w2_weight = weights[f"{expert_id}.w2.weight"]
+            weight_name = "qweight" if f"{expert_id}.w1.qweight" in weights else "weight"
+            w1_weight = weights[f"{expert_id}.w1.{weight_name}"]
+            w3_weight = weights[f"{expert_id}.w3.{weight_name}"]
+            w2_weight = weights[f"{expert_id}.w2.{weight_name}"]
             if module.bias:
                 w1_bias = weights[f"{expert_id}.w1.bias"]
                 w3_bias = weights[f"{expert_id}.w3.bias"]
@@ -1140,6 +1141,10 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM
         if w4a8_custom:
             weight_scale_name = "weight_scale_inv"
+            for expert_id in module.initial_local_expert_ids:
+                if f"{expert_id}.w3.weight_scale.int4" in weights:
+                    weight_scale_name = "weight_scale.int4"
+                    break
         else:
             weight_scale_name = "weight_scale"
 
@@ -1158,13 +1163,31 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         all_w3_w1_input_scales_max = torch.max(
             torch.stack(all_w3_input_scales),
             torch.stack(all_w1_input_scales)).max()
+        all_w3_w1_scales_fp8_max = None
+        has_fp8_weight_scale = False
         if w4a8_custom:
             # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
             module.fc31_act_scale.data.copy_(
                 torch.ones_like(module.fc31_act_scale, device=self.device) *
                 (1 / all_w3_w1_input_scales_max))
+
+            for expert_id in module.initial_local_expert_ids:
+                if f"{expert_id}.w1.weight_scale" in weights:
+                    has_fp8_weight_scale = True
+                    break
+            if has_fp8_weight_scale:
+                all_w3_w1_scales_fp8_max = []
+                for expert_id in module.initial_local_expert_ids:
+                    w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"],
+                                                            device=self.device)
+                    w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"],
+                                                            device=self.device)
+                    all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8))
+                all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape)
+            else:
+                all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha, device=self.device)
             module.fc31_alpha.data.copy_(
-                (torch.ones_like(module.fc31_alpha, device=self.device) *
+                (all_w3_w1_scales_fp8_max *
                  all_w3_w1_input_scales_max).float())
         else:
             # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
@@ -1221,6 +1244,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
+        if w4a8_custom and has_fp8_weight_scale:
+            all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
         all_w1_scales = [
             load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"],
                               module.tp_size,
@@ -1229,9 +1254,15 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
-        all_w3_w1_scales = torch.cat(
-            [torch.stack(all_w3_scales),
-             torch.stack(all_w1_scales)], dim=-2)
+        if w4a8_custom and has_fp8_weight_scale:
+            all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
+            all_w3_w1_scales = torch.cat(
+                [all_w3_scales,
+                 all_w1_scales], dim=-2)
+        else:
+            all_w3_w1_scales = torch.cat(
+                [torch.stack(all_w3_scales),
+                 torch.stack(all_w1_scales)], dim=-2)
         if module.sm_version == 89:
             w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype)
         else:
@@ -1259,14 +1290,23 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
 
         all_w2_input_scales_max = torch.stack(all_w2_input_scales).to(
             module.dtype).max()
+        all_w2_scales_fp8 = None
         if w4a8_custom:
             # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
             module.fc2_act_scale.data.copy_(
                 torch.ones_like(module.fc2_act_scale, device=self.device) *
                 (1 / all_w2_input_scales_max))
             # In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha
+            if has_fp8_weight_scale:
+                all_w2_scales_fp8 = [
+                    load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], device=self.device)
+                    for expert_id in module.initial_local_expert_ids
+                ]
+                all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape)
+            else:
+                all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha, device=self.device)
             module.fc2_alpha.data.copy_(
-                (torch.ones_like(module.fc2_alpha, device=self.device) *
+                (all_w2_scales_fp8 *
                  all_w2_input_scales_max).float())
         else:
             # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
@@ -1305,6 +1345,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                               device=self.device)
             for expert_id in module.initial_local_expert_ids
         ]
+        if w4a8_custom and has_fp8_weight_scale:
+            all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
         if module.sm_version == 89:
             w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
                 module.dtype)
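A small numerical sketch (dummy shapes and made-up values, not the real checkpoint layout) of the rescaling the W4A8-custom branch above performs when per-expert FP8 `weight_scale` tensors are present: the per-expert max of the w3/w1 FP8 scales is folded into `fc31_alpha` together with the per-tensor input-scale max, and the per-group INT4 scales are divided by that same max, so the combined per-group factor `group_scale * alpha` is unchanged by the refactoring.

```python
import torch

num_experts, hidden, groups = 4, 8, 2

# Made-up per-expert FP8 weight scales for w1/w3 and a per-tensor input-scale max.
w1_fp8 = torch.rand(num_experts) + 0.5
w3_fp8 = torch.rand(num_experts) + 0.5
input_scale_max = torch.tensor(0.75)

# Made-up per-expert, per-channel, per-group INT4 scales for the fused w3/w1 weight.
w3_w1_group_scales = torch.rand(num_experts, 2 * hidden, groups)

# What the loader does when has_fp8_weight_scale is True:
fp8_max = torch.max(w3_fp8, w1_fp8)          # per-expert max of the w3/w1 FP8 scales
alpha = (fp8_max * input_scale_max).float()  # value copied into fc31_alpha
rescaled_group_scales = w3_w1_group_scales / fp8_max.view(-1, 1, 1)

# The combined dequant factor per group is the same before and after the rescaling:
before = w3_w1_group_scales * input_scale_max
after = rescaled_group_scales * alpha.view(-1, 1, 1)
torch.testing.assert_close(before, after)
```

The `w2` path at the end of the hunk follows the same pattern with `fc2_alpha` and the per-expert `w2.weight_scale`.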
diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py
index e55735043e9..abb870f8252 100644
--- a/tensorrt_llm/llmapi/llm_utils.py
+++ b/tensorrt_llm/llmapi/llm_utils.py
@@ -25,7 +25,7 @@
 from ..logger import logger
 from ..mapping import Mapping
 from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
-from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig
+from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig, ActivationScheme
 from ..module import Module
 from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                           get_build_cache_config_from_env)
@@ -426,6 +426,11 @@ def _update_from_hf_quant_config(self) -> bool:
                     "weight_block_size"):
                 quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
                 quant_config.exclude_modules = ["*eh_proj"]
+                block_size = hf_quant_config.get("weight_block_size", [])
+                assert tuple(block_size) == (
+                    128,
+                    128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
+                quant_config.group_size = block_size[0]
             elif hf_quant_config.get("quant_method") == "mxfp4":
                 from .._torch.model_config import ModelConfig
                 quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
@@ -435,10 +440,46 @@ def _update_from_hf_quant_config(self) -> bool:
                     'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                     'embedding', 'unembedding'
                 ]
+            elif hf_quant_config.get("quant_method") == "fp8":
+                quant_config.quant_algo = QuantAlgo.FP8
+            elif hf_quant_config.get("quant_method") == "w4a8_awq":
+                quant_config.quant_algo = QuantAlgo.W4A8_AWQ
+                quant_config.group_size = hf_quant_config.get("weight_group_size", 128)
             else:
                 raise NotImplementedError(
                     f"Unsupported quantization_config: {hf_quant_config}.")
 
+            # set kv_cache_quant_algo
+            quant_config.kv_cache_quant_algo = QuantAlgo(
+                hf_quant_config.get("kv_cache_quant_method").upper()
+            ) if hf_quant_config.get("kv_cache_quant_method") else None
+            # set activation_scheme
+            quant_config.activation_scheme = ActivationScheme(
+                hf_quant_config.get("activation_scheme").upper()
+            ) if hf_quant_config.get("activation_scheme") else None
+            # set exclude_modules
+            if quant_config.exclude_modules:
+                if hf_quant_config.get("ignored_modules"):
+                    quant_config.exclude_modules += hf_quant_config.get("ignored_modules")
+            else:
+                quant_config.exclude_modules = hf_quant_config.get("ignored_modules")
+            # set exclude_quantization
+            hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
+            if hf_ignored_quantization_config:
+                quant_config.exclude_quantization = {
+                    "quant_algo": QuantAlgo(
+                        hf_ignored_quantization_config.get("quant_method").upper()
+                    ) if hf_ignored_quantization_config.get("quant_method") else None,
+                    "kv_cache_quant_algo": QuantAlgo(
+                        hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
+                    ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
+                    "activation_scheme": ActivationScheme(
+                        hf_ignored_quantization_config.get("activation_scheme").upper()
+                    ) if hf_ignored_quantization_config.get("activation_scheme") else None,
+                }
+            logger.info(
+                f"Detected quantization_config: {quant_config}."
+            )
             return True
 
         return False
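For context, a hypothetical HF-style `quantization_config` block (as it would appear in a checkpoint's config.json) that exercises the new branches of `_update_from_hf_quant_config`; the key names are the ones the code above reads, the values are purely illustrative. Note that this path reads `ignored_modules`, while `load_hf_quant_config` in `model_config.py` reads `ignored_layers`.

```python
# Hypothetical quantization_config from a checkpoint's config.json; values are
# illustrative only.
quantization_config = {
    "quant_method": "w4a8_awq",          # -> QuantAlgo.W4A8_AWQ
    "weight_group_size": 128,            # -> quant_config.group_size
    "kv_cache_quant_method": "fp8",      # -> QuantAlgo.FP8 for the KV cache
    "activation_scheme": "static",       # -> ActivationScheme.STATIC
    "ignored_modules": ["*eh_proj"],     # appended to quant_config.exclude_modules
    "ignored_quantization_config": {
        # Fallback quantization for the ignored modules above.
        "quant_method": "fp8",           # -> exclude_quantization["quant_algo"] = QuantAlgo.FP8
        "kv_cache_quant_method": None,
        "activation_scheme": "dynamic",  # -> ActivationScheme.DYNAMIC
    },
}
```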
diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py
index 7c49855d0b3..c53a01ac2ce 100644
--- a/tensorrt_llm/models/modeling_utils.py
+++ b/tensorrt_llm/models/modeling_utils.py
@@ -45,7 +45,7 @@
                                    WeightOnlyQuantLinear, WeightOnlyQuantRowLinear)
 from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST,
-                                 W8A8_SQ_PLUGIN_LIST, QuantAlgo)
+                                 W8A8_SQ_PLUGIN_LIST, QuantAlgo, ActivationScheme)
 from ..quantization.utils import fp4_utils
 from ..top_model_mixin import TopModelMixin
 from .convert_utils import weight_only_quantize_dict
@@ -140,6 +140,8 @@ class QuantConfig:
         pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False.
         exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None.
         mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None.
+        exclude_quantization (Dict, optional): The quantization settings applied to the modules matched by exclude_modules. Defaults to None.
+        activation_scheme (tensorrt_llm.quantization.mode.ActivationScheme, optional): The activation quantization scheme (static or dynamic). Defaults to None.
     """
     quant_algo: Optional[QuantAlgo] = None
     kv_cache_quant_algo: Optional[QuantAlgo] = None
@@ -151,6 +153,8 @@ class QuantConfig:
     pre_quant_scale: bool = False
     exclude_modules: Optional[List[str]] = None
     mamba_ssm_cache_dtype: Optional[str] = None
+    exclude_quantization: Optional[Dict] = None
+    activation_scheme: Optional[ActivationScheme] = None
 
     @cached_property
     def quant_mode(self) -> QuantModeWrapper:
diff --git a/tensorrt_llm/quantization/mode.py b/tensorrt_llm/quantization/mode.py
index 4615bc1376f..f6f31c4bd0d 100644
--- a/tensorrt_llm/quantization/mode.py
+++ b/tensorrt_llm/quantization/mode.py
@@ -473,3 +473,8 @@ class GroupwiseQuantAlgo:
     PRE_QUANT_SCALE = 4
     W4A8_ALPHA = 8
     INT8_WEIGHT = 16
+
+
+class ActivationScheme(StrEnum, metaclass=BaseEnumMeta):
+    STATIC = auto()
+    DYNAMIC = auto()
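Finally, a minimal, self-contained sketch of how the loaders above consume the new enum. The stand-in class below assumes that, as with `QuantAlgo` in `mode.py`, the strenum-based `auto()` gives each member a value equal to its name ("STATIC"/"DYNAMIC"), so constructing the enum from an upper-cased checkpoint string works.

```python
from enum import Enum


# Local stand-in for tensorrt_llm.quantization.mode.ActivationScheme; assumes the
# real StrEnum members carry their names as values, like QuantAlgo does.
class ActivationScheme(str, Enum):
    STATIC = "STATIC"
    DYNAMIC = "DYNAMIC"


def parse_activation_scheme(raw):
    # The guarded pattern used throughout this PR: upper-case the checkpoint
    # string when present, otherwise keep None.
    return ActivationScheme(raw.upper()) if raw else None


assert parse_activation_scheme("static") is ActivationScheme.STATIC
assert parse_activation_scheme("DYNAMIC") is ActivationScheme.DYNAMIC
assert parse_activation_scheme(None) is None
```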