Add targets and ignore support to BaseSparsityCompressor
rahul-tuli committed Oct 6, 2024
1 parent c2455b7 commit 400c6c3
Showing 4 changed files with 74 additions and 5 deletions.
@@ -18,7 +18,7 @@
 import os
 import re
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -39,6 +39,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import find_compression_targets
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -276,8 +277,9 @@ def compress(
         )
 
         if self.sparsity_compressor is not None:
+            compression_targets = self._find_sparse_compression_targets(model=model)
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict, compression_targets=compression_targets
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -368,6 +370,13 @@ def _replace_weights(self, dense_weight_generator, model):
             module = operator.attrgetter(prefix)(model)
             update_parameter_data(module, data, param_name)
 
+    def _find_sparse_compression_targets(self, model: Module) -> Set[str]:
+        return find_compression_targets(
+            model=model,
+            targets=self.sparsity_config.targets,
+            ignore=self.sparsity_config.ignore,
+        )
+
 
 def map_modules_to_quant_args(model: Module) -> Dict:
     quantized_modules_to_args = {}
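Taken together, the changes to this file mean the model compressor now resolves the sparsity config's targets and ignore fields into a set of layer names before delegating to the sparsity compressor. A minimal sketch of a config that would exercise the new path, assuming the config class exposes targets and ignore fields as the new helper requires; the field values are illustrative, not taken from this commit:

    # Illustrative only: compress every Linear leaf module except the LM head.
    # `targets` and `ignore` are the config fields read by
    # _find_sparse_compression_targets above.
    from compressed_tensors.config import SparsityCompressionConfig

    sparsity_config = SparsityCompressionConfig(
        format="sparse-bitmask",
        targets=["Linear"],
        ignore=["re:.*lm_head"],
    )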
13 changes: 11 additions & 2 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -59,18 +59,27 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress;
+            if None, compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            prefix = name.rsplit(".", 1)[0]
+            if compression_targets and prefix not in compression_targets:
+                continue
             compression_data = self.compress_weight(name, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
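The new filter keys on each parameter's module prefix, i.e. everything before the final dot: for model.layers.0.self_attn.q_proj.weight the prefix is model.layers.0.self_attn.q_proj. A sketch of the new call shape, with the compressor instance and weight tensors assumed rather than shown:

    # Assumed: `compressor` is an instance of any BaseSparseCompressor subclass
    # and the two weight tensors are dense torch.Tensors.
    state_dict = {
        "model.layers.0.self_attn.q_proj.weight": q_proj_weight,
        "model.layers.0.self_attn.k_proj.weight": k_proj_weight,
    }
    compressed = compressor.compress(
        state_dict,
        compression_targets={"model.layers.0.self_attn.q_proj"},  # k_proj is skipped
    )

Note that the guard is written as `if compression_targets and ...`, so an empty set behaves like the default None: every layer is compressed. Only a non-empty set actually restricts compression.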
26 changes: 25 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -18,7 +18,7 @@
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Union
+from typing import Set, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -56,6 +56,7 @@
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
+    "find_compression_targets",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -276,6 +277,29 @@ def find_name_or_class_matches(
     return matches
 
 
+def find_compression_targets(
+    model: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> Set[str]:
+    """
+    Finds all modules in the model that match an entry in the targets list
+    and do not match any entry in the ignore list
+
+    Note: targets must be regexes, layer types, or full layer names
+
+    :param model: model to search for targets in
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: set of all module names that match the targets and are
+        not ignored
+    """
+    current_targets = set()
+    for name, module in iter_named_leaf_modules(model):
+        if find_name_or_class_matches(
+            name, module, targets
+        ) and not find_name_or_class_matches(name, module, ignore):
+            current_targets.add(name)
+    return current_targets
+
+
 def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
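Per the docstring, each entry in targets and ignore may be a regex (prefixed with "re:"), a layer type such as Linear, or a fully qualified layer name, and a module is kept only when it matches targets without also matching ignore. A short usage sketch; the model and layer names are assumed, following a typical Llama-style module tree:

    # Assumed: `model` is a loaded torch.nn.Module, e.g. a small Llama variant.
    sparse_targets = find_compression_targets(
        model=model,
        targets=["Linear"],        # every Linear leaf module...
        ignore=["re:.*lm_head"],   # ...unless its name ends in lm_head
    )
    # Result: fully qualified names such as "model.layers.0.self_attn.q_proj"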
27 changes: 27 additions & 0 deletions tests/test_quantization/lifecycle/test_apply.py
@@ -27,6 +27,7 @@
 from compressed_tensors.quantization.lifecycle import (
     apply_quantization_config,
     apply_quantization_status,
+    find_compression_targets,
 )
 from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from transformers import AutoModelForCausalLM
@@ -272,3 +273,29 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
         assert len(caplog.text) > 0
     else:
         assert len(caplog.text) == 0
+
+
+@pytest.fixture
+def model():
+    return AutoModelForCausalLM.from_pretrained(
+        "Xenova/llama2.c-stories15M",
+        torch_dtype="auto",
+    )
+
+
+@pytest.mark.parametrize(
+    "targets, ignore, expected",
+    [
+        # ignore all
+        (["Linear"], ["Linear"], set()),
+        # ignore subset
+        (
+            ["re:model.layers.[01].self_attn.q_proj"],
+            ["re:model.layers.1.self_attn.q_proj"],
+            set(["model.layers.0.self_attn.q_proj"]),
+        ),
+    ],
+)
+def test_find_compression_targets(model, targets, ignore, expected):
+    actual_targets = find_compression_targets(model, targets, ignore)
+    assert actual_targets == expected