From 400c6c391aa2fa316da1b61b3de648a9095bf2bd Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Sun, 6 Oct 2024 22:45:56 +0000
Subject: [PATCH] Add targets and ignore support to BaseSparsityCompressor

---
 .../model_compressors/model_compressor.py     | 13 +++++++--
 .../compressors/sparse_compressors/base.py    | 13 +++++++--
 .../quantization/lifecycle/apply.py           | 26 +++++++++++++++++-
 .../test_quantization/lifecycle/test_apply.py | 27 +++++++++++++++++++
 4 files changed, 74 insertions(+), 5 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 6473554d..a825f149 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -18,7 +18,7 @@
 import os
 import re
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -39,6 +39,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import find_compression_targets
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -276,8 +277,9 @@ def compress(
         )
 
         if self.sparsity_compressor is not None:
+            compression_targets = self._find_sparse_compression_targets(model=model)
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict, compression_targets=compression_targets
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -368,6 +370,13 @@ def _replace_weights(self, dense_weight_generator, model):
             module = operator.attrgetter(prefix)(model)
             update_parameter_data(module, data, param_name)
 
+    def _find_sparse_compression_targets(self, model: Module) -> Set[str]:
+        return find_compression_targets(
+            model=model,
+            targets=self.sparsity_config.targets,
+            ignore=self.sparsity_config.ignore,
+        )
+
 
 def map_modules_to_quant_args(model: Module) -> Dict:
     quantized_modules_to_args = {}
diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py
index 1b1a6825..96fd3591 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -59,11 +59,17 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress;
+            if None, all layers are compressed (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,6 +77,9 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            prefix = name.rsplit(".", 1)[0]
+            if compression_targets and prefix not in compression_targets:
+                continue
             compression_data = self.compress_weight(name, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index a66dba92..5e0d8f20 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -18,7 +18,7 @@
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Union
+from typing import Set, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -56,6 +56,7 @@
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
+    "find_compression_targets",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -276,6 +277,29 @@ def find_name_or_class_matches(
     return matches
 
 
+def find_compression_targets(
+    model: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> Set[str]:
+    """
+    Finds all leaf modules in the model that match the targets list and are
+    not matched by the ignore list
+
+    Note: targets may be regexes (prefixed with "re:"), layer types, or full layer names
+
+    :param model: model to search for targets in
+    :param targets: list of targets to match modules against
+    :param ignore: list of targets to exclude from the matches
+    :return: set of layer names that match the targets and are not ignored
+    """
+    current_targets = set()
+    for name, module in iter_named_leaf_modules(model):
+        if find_name_or_class_matches(
+            name, module, targets
+        ) and not find_name_or_class_matches(name, module, ignore):
+            current_targets.add(name)
+    return current_targets
+
+
 def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 5f0bd093..a673ae83 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -27,6 +27,7 @@
 from compressed_tensors.quantization.lifecycle import (
     apply_quantization_config,
     apply_quantization_status,
+    find_compression_targets,
 )
 from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from transformers import AutoModelForCausalLM
@@ -272,3 +273,29 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
         assert len(caplog.text) > 0
     else:
         assert len(caplog.text) == 0
+
+
+@pytest.fixture
+def model():
+    return AutoModelForCausalLM.from_pretrained(
+        "Xenova/llama2.c-stories15M",
+        torch_dtype="auto",
+    )
+
+
+@pytest.mark.parametrize(
+    "targets, ignore, expected",
+    [
+        # ignore all targeted layers
+        (["Linear"], ["Linear"], set()),
+        # ignore a subset of the targeted layers
+        (
+            ["re:model.layers.[01].self_attn.q_proj"],
+            ["re:model.layers.1.self_attn.q_proj"],
+            {"model.layers.0.self_attn.q_proj"},
+        ),
+    ],
+)
+def test_find_compression_targets(model, targets, ignore, expected):
+    actual_targets = find_compression_targets(model, targets, ignore)
+    assert actual_targets == expected
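
Reviewer note: below is a minimal, self-contained sketch of how the new
find_compression_targets helper resolves which layers get sparsity-compressed.
The toy model and the module names "proj_a"/"proj_b" are hypothetical and
chosen only for illustration; matching by layer type ("Linear") and exclusion
by name follow the behavior exercised in the tests above.

    import torch.nn as nn

    from compressed_tensors.quantization.lifecycle import find_compression_targets

    # Two leaf Linear modules under hypothetical names.
    model = nn.ModuleDict({"proj_a": nn.Linear(8, 8), "proj_b": nn.Linear(8, 8)})

    # Target every Linear layer by type, but exclude "proj_b" by full name;
    # regex patterns (prefixed with "re:") are also supported, as in the tests.
    targets = find_compression_targets(model, targets=["Linear"], ignore=["proj_b"])
    print(targets)  # expected: {"proj_a"}

ModelCompressor.compress passes the resolved set through to
BaseSparseCompressor.compress, which skips any parameter whose module prefix
is not in the set; calling compress without compression_targets preserves the
old compress-everything behavior.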