-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Support W4A8 method of AngleSlim tool #6857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.logger import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.mapping import Mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.models.modeling_utils import QuantConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.quantization.mode import QuantAlgo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -286,6 +286,56 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return quant_config, layer_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config = QuantConfig() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer_quant_config = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with open(quant_config_file) as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config_dict = json.load(f) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_quant_configs = quant_config_dict['quantization'] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_quant_configs.get('quant_algo', None).upper()) if json_quant_configs.get("quant_algo") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.quant_algo == "fp8_pb_wo": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.kv_cache_quant_algo = QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_quant_configs.get("kv_cache_quant_algo").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if json_quant_configs.get("kv_cache_quant_algo") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.group_size = json_quant_configs.get('group_size', None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_modules = json_quant_configs.get( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'exclude_modules', None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.activation_scheme = ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_quant_configs.get('activation_scheme', None).upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if json_quant_configs.get("activation_scheme") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_exclude_quantization= json_quant_configs.get('exclude_quantization', None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if json_exclude_quantization: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quant_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_exclude_quantization.get('quant_algo', None).upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if json_exclude_quantization.get("quant_algo") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "kv_cache_quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_exclude_quantization.get("kv_cache_quant_algo").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if json_exclude_quantization.get("kv_cache_quant_algo") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "activation_scheme": ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| json_exclude_quantization.get('activation_scheme', None).upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if json_exclude_quantization.get("activation_scheme") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "group_size": json_exclude_quantization.get('group_size', None), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.exclude_quantization["quant_algo"] in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.exclude_quantization["group_size"] is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quantization["group_size"] = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.quant_algo in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.group_size is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.group_size = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+315
to
+335
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo breaks exclude overrides: use QuantConfig.exclude_quantization consistently and initialize it. Lines 269 and 356 use exclude_quant_config but QuantConfig defines exclude_quantization. Also, subsequent code (Line 281) expects exclude_quantization. This prevents per-module overrides from being applied and can raise exceptions. - json_exclude_quantization= json_quant_configs.get('exclude_quantization', None)
+ json_exclude_quantization = json_quant_configs.get('exclude_quantization', None)
if json_exclude_quantization:
- quant_config.exclude_quant_config = {
+ quant_config.exclude_quantization = {
"quant_algo": QuantAlgo(
json_exclude_quantization.get('quant_algo', None).upper()
) if json_exclude_quantization.get("quant_algo") else None,
"kv_cache_quant_algo": QuantAlgo(
json_exclude_quantization.get("kv_cache_quant_algo").upper()
) if json_exclude_quantization.get("kv_cache_quant_algo") else None,
"activation_scheme": ActivationScheme(
json_exclude_quantization.get('activation_scheme', None).upper()
) if json_exclude_quantization.get("activation_scheme") else None,
"group_size": json_exclude_quantization.get('group_size', None),
}
- if quant_config.exclude_quantization["quant_algo"] in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]:
+ if quant_config.exclude_quantization["quant_algo"] in {QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ}:
if quant_config.exclude_quantization["group_size"] is None:
quant_config.exclude_quantization["group_size"] = 128
-
- if quant_config.quant_algo in [QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ]:
+ if quant_config.quant_algo in {QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ}:
if quant_config.group_size is None:
quant_config.group_size = 128📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return quant_config, layer_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+333
to
+337
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Replicate TRTLLM default excludes for FP8_BLOCK_SCALES (parity with modelopt path). For AngelsLim FP8 block scales, set default exclude_modules when moe_backend == 'TRTLLM', as done in load_modelopt_quant_config. - return quant_config, layer_quant_config
+ if (moe_backend == 'TRTLLM'
+ and quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+ and quant_config.exclude_modules is None):
+ quant_config.exclude_modules = ["*kv_b_proj*", "*k_b_proj*", "*eh_proj"]
+ return quant_config, layer_quant_config📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_algo = ModelConfig.override_quant_algo() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -330,6 +380,58 @@ def load_hf_quant_config(hf_quant_config, moe_backend): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'embedding', 'unembedding' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_quant_config.get("quant_method") == "fp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo.FP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_quant_config.get("quant_method") == "w4a8_awq": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo.W4A8_AWQ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.group_size = hf_quant_config.get("weight_group_size", 128) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set kv_cache_quant_algo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_quant_config.get("kv_cache_quant_method") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set activation_scheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_quant_config.get("activation_scheme") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set exclude_modules | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.exclude_modules: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_quant_config.get("ignored_layers"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_modules += hf_quant_config.get("ignored_layers") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_modules = hf_quant_config.get("ignored_layers") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+398
to
+403
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Accept both 'ignored_modules' and 'ignored_layers' from HF config. Two places in the codebase use different keys. Normalize here to avoid silently missing excludes. - if quant_config.exclude_modules:
- if hf_quant_config.get("ignored_layers"):
- quant_config.exclude_modules += hf_quant_config.get("ignored_layers")
- else:
- quant_config.exclude_modules = hf_quant_config.get("ignored_layers")
+ ignored = hf_quant_config.get("ignored_modules") or hf_quant_config.get("ignored_layers")
+ if ignored:
+ if quant_config.exclude_modules:
+ quant_config.exclude_modules += ignored
+ else:
+ quant_config.exclude_modules = ignored📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set exclude_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_ignored_quantization_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quant_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "kv_cache_quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config.get("kv_cache_quant_method").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "activation_scheme": ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config.get("activation_scheme").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_ignored_quantization_config.get("activation_scheme") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "group_size": 128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_ignored_quantization_config.get( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "quant_method") == "fp8" and hf_ignored_quantization_config.get("weight_block_size", []): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size = hf_ignored_quantization_config.get("weight_block_size", []) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert tuple(block_size) == ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quantization["group_size"] = block_size[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_ignored_quantization_config.get("quant_method") == "fp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_ignored_quantization_config.get("quant_method") == "w4a8_awq": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quantization["group_size"] = hf_ignored_quantization_config.get( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "weight_group_size", 128) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"{hf_ignored_quantization_config}.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+404
to
+433
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix exclude overrides for HF: wrong attribute name + missing FP8 block handling. Same naming typo as above; also add FP8_BLOCK_SCALES handling when ignored_quantization_config carries weight_block_size. - hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
- if hf_ignored_quantization_config:
- quant_config.exclude_quant_config = {
- "kv_cache_quant_algo": QuantAlgo(
- hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
- ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
- "activation_scheme": ActivationScheme(
- hf_ignored_quantization_config.get("activation_scheme").upper()
- ) if hf_ignored_quantization_config.get("activation_scheme") else None,
- "group_size": 128,
- }
- if hf_ignored_quantization_config.get(
- "quant_method") == "fp8" and hf_ignored_quantization_config.get("weight_block_size", []):
- quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
- block_size = hf_ignored_quantization_config.get("weight_block_size", [])
- assert tuple(block_size) == (
- 128,
- 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
- quant_config.exclude_quantization["group_size"] = block_size[0]
- elif hf_ignored_quantization_config.get("quant_method") == "fp8":
- quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
- elif hf_ignored_quantization_config.get("quant_method") == "w4a8_awq":
- quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ
- quant_config.exclude_quantization["group_size"] = hf_ignored_quantization_config.get(
- "weight_group_size", 128)
- else:
- raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: "
- f"{hf_ignored_quantization_config}.")
+ hf_ignored_quant = hf_quant_config.get("ignored_quantization_config")
+ if hf_ignored_quant:
+ quant_config.exclude_quantization = {
+ "kv_cache_quant_algo": QuantAlgo(hf_ignored_quant["kv_cache_quant_method"].upper())
+ if hf_ignored_quant.get("kv_cache_quant_method") else None,
+ "activation_scheme": ActivationScheme(hf_ignored_quant["activation_scheme"].upper())
+ if hf_ignored_quant.get("activation_scheme") else None,
+ "group_size": 128,
+ }
+ if hf_ignored_quant.get("quant_method") == "fp8" and hf_ignored_quant.get("weight_block_size"):
+ block_size = hf_ignored_quant["weight_block_size"]
+ assert tuple(block_size) == (128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
+ quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
+ quant_config.exclude_quantization["group_size"] = block_size[0]
+ elif hf_ignored_quant.get("quant_method") == "fp8":
+ quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
+ elif hf_ignored_quant.get("quant_method") == "w4a8_awq":
+ quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ
+ quant_config.exclude_quantization["group_size"] = hf_ignored_quant.get("weight_group_size", 128)
+ else:
+ raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: {hf_ignored_quant}.")📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return quant_config, layer_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -412,6 +514,9 @@ def from_pretrained(cls, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (quant_config_file := model_dir / 'hf_quant_config.json').exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config, layer_quant_config = cls.load_modelopt_quant_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config_file, model_dir, moe_backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif (quant_config_file := model_dir / 'angelslim_hf_quant_config.json').exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config, layer_quant_config = cls.load_angelslim_quant_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config_file, model_dir, moe_backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+517
to
+519
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainVerify the consistency of the naming convention for config files. The code now supports both 🏁 Script executed: #!/bin/bash
# Check for any existing references to these config file names in the codebase
echo "Searching for references to angelslim_hf_quant_config.json:"
rg "angelslim_hf_quant_config" --type py -A 2 -B 2
echo -e "\nSearching for references to hf_quant_config.json:"
rg "hf_quant_config\.json" --type py -A 2 -B 2
echo -e "\nChecking for any documentation about these config files:"
fd -e md -e rst | xargs rg -l "angelslim_hf_quant_config|hf_quant_config"Length of output: 8264 Update documentation to include the new AngelsLim config filename The code now recognizes both • docs/source/performance/perf-benchmarking.md – Add a note explaining that if an 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # quantized ckpt in other formats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hasattr(pretrained_config, "quantization_config"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_quant_config = pretrained_config.quantization_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -477,7 +477,20 @@ def apply_quant_config_exclude_modules(self): | |
| kv_cache_quant_algo = None | ||
| if quant_config: | ||
| kv_cache_quant_algo = quant_config.kv_cache_quant_algo | ||
| new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo) | ||
| quant_algo = None | ||
| activation_scheme = None | ||
| group_size = 128 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are quantization parameters like quant_algo and group_size necessary in the modules that skip quantization? |
||
| exclude_quantization = quant_config.exclude_quantization | ||
| if exclude_quantization: | ||
| quant_algo = exclude_quantization.get("quant_algo", None) | ||
| activation_scheme = exclude_quantization.get("activation_scheme", None) | ||
| group_size = exclude_quantization.get("group_size", 128) | ||
| new_config = QuantConfig( | ||
| quant_algo=quant_algo, | ||
| kv_cache_quant_algo=kv_cache_quant_algo, | ||
| activation_scheme=activation_scheme, | ||
| group_size=group_size, | ||
| ) | ||
|
|
||
| if quant_config is not None: | ||
| if quant_config.exclude_modules is not None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -224,9 +224,10 @@ def load_expert_weights_to_dst( | |
| MoEWeightLoadingMode.VANILLA, | ||
| MoEWeightLoadingMode.W4A8_CUSTOM | ||
| ]: | ||
| w1_weight = weights[f"{expert_id}.w1.weight"] | ||
| w3_weight = weights[f"{expert_id}.w3.weight"] | ||
| w2_weight = weights[f"{expert_id}.w2.weight"] | ||
| weight_name = "qweight" if f"{expert_id}.w1.qweight" in weights else "weight" | ||
| w1_weight = weights[f"{expert_id}.w1.{weight_name}"] | ||
| w3_weight = weights[f"{expert_id}.w3.{weight_name}"] | ||
| w2_weight = weights[f"{expert_id}.w2.{weight_name}"] | ||
| if module.bias: | ||
| w1_bias = weights[f"{expert_id}.w1.bias"] | ||
| w3_bias = weights[f"{expert_id}.w3.bias"] | ||
|
|
@@ -1140,6 +1141,10 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
| w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM | ||
| if w4a8_custom: | ||
| weight_scale_name = "weight_scale_inv" | ||
| for expert_id in module.initial_local_expert_ids: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work fine for the original w4a8 implementation? |
||
| if f"{expert_id}.w3.weight_scale.int4" in weights: | ||
| weight_scale_name = "weight_scale.int4" | ||
| break | ||
| else: | ||
| weight_scale_name = "weight_scale" | ||
|
|
||
|
|
@@ -1158,13 +1163,31 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
| all_w3_w1_input_scales_max = torch.max( | ||
| torch.stack(all_w3_input_scales), | ||
| torch.stack(all_w1_input_scales)).max() | ||
| all_w3_w1_scales_fp8_max = None | ||
| has_fp8_weight_scale = False | ||
| if w4a8_custom: | ||
| # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale | ||
| module.fc31_act_scale.data.copy_( | ||
| torch.ones_like(module.fc31_act_scale, device=self.device) * | ||
| (1 / all_w3_w1_input_scales_max)) | ||
|
|
||
| for expert_id in module.initial_local_expert_ids: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
| if f"{expert_id}.w1.weight_scale" in weights: | ||
| has_fp8_weight_scale = True | ||
| break | ||
| if has_fp8_weight_scale: | ||
| all_w3_w1_scales_fp8_max = [] | ||
| for expert_id in module.initial_local_expert_ids: | ||
| w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"], | ||
| device=self.device) | ||
| w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], | ||
| device=self.device) | ||
| all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8)) | ||
| all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape) | ||
| else: | ||
| all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha, device=self.device) | ||
| module.fc31_alpha.data.copy_( | ||
| (torch.ones_like(module.fc31_alpha, device=self.device) * | ||
| (all_w3_w1_scales_fp8_max * | ||
| all_w3_w1_input_scales_max).float()) | ||
| else: | ||
| # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored | ||
|
|
@@ -1221,6 +1244,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
| device=self.device) | ||
| for expert_id in module.initial_local_expert_ids | ||
| ] | ||
| if w4a8_custom and has_fp8_weight_scale: | ||
| all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) | ||
| all_w1_scales = [ | ||
| load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"], | ||
| module.tp_size, | ||
|
|
@@ -1229,9 +1254,15 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
| device=self.device) | ||
| for expert_id in module.initial_local_expert_ids | ||
| ] | ||
| all_w3_w1_scales = torch.cat( | ||
| [torch.stack(all_w3_scales), | ||
| torch.stack(all_w1_scales)], dim=-2) | ||
| if w4a8_custom and has_fp8_weight_scale: | ||
| all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) | ||
| all_w3_w1_scales = torch.cat( | ||
| [all_w3_scales, | ||
| all_w1_scales], dim=-2) | ||
| else: | ||
| all_w3_w1_scales = torch.cat( | ||
| [torch.stack(all_w3_scales), | ||
| torch.stack(all_w1_scales)], dim=-2) | ||
| if module.sm_version == 89: | ||
| w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype) | ||
| else: | ||
|
|
@@ -1259,14 +1290,23 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
| all_w2_input_scales_max = torch.stack(all_w2_input_scales).to( | ||
| module.dtype).max() | ||
|
|
||
| all_w2_scales_fp8 = None | ||
| if w4a8_custom: | ||
| # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale | ||
| module.fc2_act_scale.data.copy_( | ||
| torch.ones_like(module.fc2_act_scale, device=self.device) * | ||
| (1 / all_w2_input_scales_max)) | ||
| # In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha | ||
| if has_fp8_weight_scale: | ||
| all_w2_scales_fp8 = [ | ||
| load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], device=self.device) | ||
| for expert_id in module.initial_local_expert_ids | ||
| ] | ||
| all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape) | ||
| else: | ||
| all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha, device=self.device) | ||
| module.fc2_alpha.data.copy_( | ||
| (torch.ones_like(module.fc2_alpha, device=self.device) * | ||
| (all_w2_scales_fp8 * | ||
| all_w2_input_scales_max).float()) | ||
| else: | ||
| # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored | ||
|
|
@@ -1305,6 +1345,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
| device=self.device) | ||
| for expert_id in module.initial_local_expert_ids | ||
| ] | ||
| if w4a8_custom and has_fp8_weight_scale: | ||
| all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2) | ||
| if module.sm_version == 89: | ||
| w2_scales = torch.stack(all_w2_scales).to(torch.float16).view( | ||
| module.dtype) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -25,7 +25,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..logger import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..mapping import Mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig, ActivationScheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..module import Module | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_build_cache_config_from_env) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -426,6 +426,11 @@ def _update_from_hf_quant_config(self) -> bool: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "weight_block_size"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_modules = ["*eh_proj"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size = hf_quant_config.get("weight_block_size", []) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert tuple(block_size) == ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.group_size = block_size[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_quant_config.get("quant_method") == "mxfp4": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .._torch.model_config import ModelConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -435,10 +440,46 @@ def _update_from_hf_quant_config(self) -> bool: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'block.*.attn.out', 'block.*.mlp.gate', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'block.*.attn.qkv', 'embedding', 'unembedding' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_quant_config.get("quant_method") == "fp8": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo.FP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif hf_quant_config.get("quant_method") == "w4a8_awq": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.quant_algo = QuantAlgo.W4A8_AWQ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.group_size = hf_quant_config.get("weight_group_size", 128) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Unsupported quantization_config: {hf_quant_config}.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set kv_cache_quant_algo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.kv_cache_quant_algo = QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_quant_config.get("kv_cache_quant_method").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_quant_config.get("kv_cache_quant_method") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set activation_scheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.activation_scheme = ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_quant_config.get("activation_scheme").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_quant_config.get("activation_scheme") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set exclude_modules | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if quant_config.exclude_modules: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_quant_config.get("ignored_modules"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_modules += hf_quant_config.get("ignored_modules") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_modules = hf_quant_config.get("ignored_modules") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set exclude_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+461
to
+466
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Accept both 'ignored_modules' and 'ignored_layers'. Align with the other loader to avoid missing excludes. - if quant_config.exclude_modules:
- if hf_quant_config.get("ignored_modules"):
- quant_config.exclude_modules += hf_quant_config.get("ignored_modules")
- else:
- quant_config.exclude_modules = hf_quant_config.get("ignored_modules")
+ ignored = hf_quant_config.get("ignored_modules") or hf_quant_config.get("ignored_layers")
+ if ignored:
+ if quant_config.exclude_modules:
+ quant_config.exclude_modules += ignored
+ else:
+ quant_config.exclude_modules = ignored📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hf_ignored_quantization_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| quant_config.exclude_quant_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config.get("quant_method").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_ignored_quantization_config.get("quant_method") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "kv_cache_quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config.get("kv_cache_quant_method").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "activation_scheme": ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_ignored_quantization_config.get("activation_scheme").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) if hf_ignored_quantization_config.get("activation_scheme") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+467
to
+479
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix exclude overrides: wrong attribute, add FP8 block handling, and group_size. Use exclude_quantization (not exclude_quant_config) and mirror FP8_BLOCK_SCALES logic. - hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
- if hf_ignored_quantization_config:
- quant_config.exclude_quant_config = {
- "quant_algo": QuantAlgo(
- hf_ignored_quantization_config.get("quant_method").upper()
- ) if hf_ignored_quantization_config.get("quant_method") else None,
- "kv_cache_quant_algo": QuantAlgo(
- hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
- ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
- "activation_scheme": ActivationScheme(
- hf_ignored_quantization_config.get("activation_scheme").upper()
- ) if hf_ignored_quantization_config.get("activation_scheme") else None,
- }
+ ignored_q = hf_quant_config.get("ignored_quantization_config")
+ if ignored_q:
+ quant_config.exclude_quantization = {
+ "kv_cache_quant_algo": QuantAlgo(ignored_q["kv_cache_quant_method"].upper())
+ if ignored_q.get("kv_cache_quant_method") else None,
+ "activation_scheme": ActivationScheme(ignored_q["activation_scheme"].upper())
+ if ignored_q.get("activation_scheme") else None,
+ "group_size": 128,
+ }
+ if ignored_q.get("quant_method") == "fp8" and ignored_q.get("weight_block_size"):
+ block_size = ignored_q["weight_block_size"]
+ assert tuple(block_size) == (128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
+ quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
+ quant_config.exclude_quantization["group_size"] = block_size[0]
+ elif ignored_q.get("quant_method") == "fp8":
+ quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
+ elif ignored_q.get("quant_method") == "w4a8_awq":
+ quant_config.exclude_quantization["quant_algo"] = QuantAlgo.W4A8_AWQ
+ quant_config.exclude_quantization["group_size"] = ignored_q.get("weight_group_size", 128)
+ else:
+ raise NotImplementedError(f"Unsupported quantization_config.ignored_quantization_config: {ignored_q}.")📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Detected quantization_config: {quant_config}." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Logical issue: string comparison against enum value.
Line 254 compares
quant_config.quant_algo(which is now aQuantAlgoenum) against the string"fp8_pb_wo". This will always fail because you're comparing an enum to a string.Fix the comparison:
📝 Committable suggestion
🤖 Prompt for AI Agents