move llmcompressor util is_model_path_quantized to ct #246

Open · wants to merge 5 commits into base: main · Changes from 3 commits
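The diff below renames the helper is_module_quantized to is_submodule_quantized across the library and adds a new helper, is_model_quantized_from_path, to quantization/utils/helpers.py. As a minimal, runnable sketch of the renamed check (the Linear layer and sizes are illustrative and not taken from this PR; assumes torch and compressed-tensors are installed):

# Minimal sketch: is_submodule_quantized reports whether a module carries a
# non-empty quantization scheme. A fresh Linear has none attached yet.
import torch
from compressed_tensors.quantization.utils import is_submodule_quantized

layer = torch.nn.Linear(16, 16)
print(is_submodule_quantized(layer))  # False: no quantization_scheme attached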
@@ -42,7 +42,7 @@
 from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
-    is_module_quantized,
+    is_submodule_quantized,
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
@@ -426,7 +426,7 @@ def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
"""
quantized_modules_to_args = {}
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
if is_submodule_quantized(submodule):
if submodule.quantization_scheme.weights is not None:
name = fix_fsdp_module_name(name)
quantized_modules_to_args[name] = submodule.quantization_scheme.weights
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -56,7 +56,7 @@
"is_sparse_target",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
from compressed_tensors.quantization.utils.helpers import is_submodule_quantized
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict


@@ -76,7 +76,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     state_dict = get_quantization_state_dict(model_path)

     for name, submodule in iter_named_leaf_modules(model):
-        if not is_module_quantized(submodule):
+        if not is_submodule_quantized(submodule):
             continue
         if submodule.quantization_scheme.weights is not None:
             base_name = "weight"
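For context on this hunk, a hedged call-pattern sketch only (not runnable as-is): the checkpoint path is a placeholder and `model` is assumed to be an initialized module whose submodules already carry quantization schemes.

# Call-pattern sketch: reload saved scales/zero-points into an existing model.
from compressed_tensors.quantization.lifecycle.apply import load_pretrained_quantization

load_pretrained_quantization(model, "path/to/quantized-checkpoint")  # placeholder path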
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/quant_config.py
@@ -23,7 +23,7 @@
 )
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
-    is_module_quantized,
+    is_submodule_quantized,
     iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
@@ -181,7 +181,7 @@ def from_pretrained(
             model, include_children=True, include_attn=True
         ):
             layer_type = module_type(submodule)
-            if not is_module_quantized(submodule):
+            if not is_submodule_quantized(submodule):
                 if layer_type not in ignore:
                     ignore[layer_type] = []
                 ignore[layer_type].append(name)
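The hunk above collects the layer types of non-quantized leaf modules into the ignore list when a config is rebuilt from a live model. A hedged sketch of the call, assuming QuantizationConfig is re-exported from compressed_tensors.quantization and `model` already has schemes attached:

# Sketch, not from this PR: rebuild a serializable quantization config from a
# model; module types failing is_submodule_quantized land in config.ignore.
from compressed_tensors.quantization import QuantizationConfig  # assumed re-export

config = QuantizationConfig.from_pretrained(model)  # `model` assumed quantized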
50 changes: 37 additions & 13 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -26,24 +26,26 @@
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 from tqdm import tqdm
+from transformers import AutoConfig


 __all__ = [
-    "infer_quantization_status",
-    "is_module_quantized",
-    "is_model_quantized",
-    "module_type",
+    "KV_CACHE_TARGETS",
Member Author (inline comment on this line): Alphabetical
"calculate_compression_ratio",
"get_torch_bit_depth",
"calculate_qparams",
"calculate_range",
"can_quantize",
"parse_out_kv_cache_args",
"KV_CACHE_TARGETS",
"compute_dynamic_scales_and_zp",
"get_torch_bit_depth",
"infer_quantization_status",
"is_kv_cache_quant_scheme",
"is_model_quantized",
"is_model_quantized_from_path",
"is_submodule_quantized",
"iter_named_leaf_modules",
"iter_named_quantizable_modules",
"compute_dynamic_scales_and_zp",
"calculate_range",
"calculate_qparams",
"module_type",
"parse_out_kv_cache_args",
]

# target the self_attn layer
@@ -167,7 +169,7 @@ def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:
     return None


-def is_module_quantized(module: Module) -> bool:
+def is_submodule_quantized(module: Module) -> bool:
"""
Check if a module is quantized, based on the existence of a non-empty quantization
scheme
@@ -200,12 +202,31 @@ def is_model_quantized(model: Module) -> bool:
"""

for _, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
if is_submodule_quantized(submodule):
return True

return False


def is_model_quantized_from_path(path: str) -> bool:
"""
Determine if model stub or path is quantized based
on the config

:param path: path to the model or HF stub
:return: True if config contains quantization_config from the given path

"""
config = AutoConfig.from_pretrained(path)
if config is not None:
if (
hasattr(config, "quantization_config")
and config.quantization_config["quant_method"] == "compressed-tensors"
):
return True
return False


def module_type(module: Module) -> str:
"""
Gets a string representation of a module type
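A short usage sketch of the newly added helper: per the diff, it loads the config with AutoConfig and returns True only when the config declares a compressed-tensors quantization_config. The path below is a placeholder; substitute a real local checkpoint or Hugging Face stub.

# Sketch: check whether a saved checkpoint declares
# quant_method == "compressed-tensors" in its config.
from compressed_tensors.quantization.utils.helpers import is_model_quantized_from_path

if is_model_quantized_from_path("path/to/model"):  # placeholder path
    print("checkpoint was saved with compressed-tensors quantization")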
@@ -331,7 +352,10 @@ def calculate_compression_ratio(model: Module) -> float:
         for parameter in model.parameters():
             uncompressed_bits = get_torch_bit_depth(parameter)
             compressed_bits = uncompressed_bits
-            if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
+            if (
+                is_submodule_quantized(submodule)
+                and submodule.quantization_scheme.weights
+            ):
                 compressed_bits = submodule.quantization_scheme.weights.num_bits

             num_weights = parameter.numel()
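For intuition on the hunk above: per the visible code, uncompressed bits come from get_torch_bit_depth and compressed bits from the scheme's num_bits. The arithmetic below is illustrative only, with made-up values, and assumes the ratio is reported as uncompressed over compressed bits.

# Illustrative arithmetic, not the library's code: a float16 parameter (16 bits
# per element) on a module quantized to 4-bit weights contributes a 4x ratio.
uncompressed_bits = 16  # e.g. get_torch_bit_depth on a float16 parameter
compressed_bits = 4     # e.g. submodule.quantization_scheme.weights.num_bits
num_weights = 1024
print((uncompressed_bits * num_weights) / (compressed_bits * num_weights))  # 4.0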