From 6c9ee5a8c2ab29682514ada95c4938c094e10a1c Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Mon, 7 Oct 2024 00:35:11 +0000
Subject: [PATCH 1/5] Make compressors stackable

Special condition for marlin_24 compressor
Update tests

Signed-off-by: Rahul Tuli
---
 .../compression/quantization_format.py        |  7 ++---
 .../compression/sparsity_config.py            | 10 ++++---
 .../compressed_tensors_utils.py               | 30 ++++++++++++-------
 .../compression/test_infer_quant_format.py    |  4 ++-
 4 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index 17f9400cf..95e1c0349 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -1,7 +1,6 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityCompressionConfig
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -16,7 +15,7 @@ def infer_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
+    sparsity_structure: Optional[str] = None,
 ) -> str:
     """
     Infers a quantization format based on model state and compression args
@@ -36,9 +35,7 @@ def infer_quantization_format(
 
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
-        is_24_structure = (
-            sparsity_config and sparsity_config.sparsity_structure == "2:4"
-        )
+        is_24_structure = sparsity_structure is not None and sparsity_structure == "2:4"
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
         if is_weight_only:  # w4a16 and w8a16
diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index d6ed9f7e7..22bfbbb06 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,6 @@
 from typing import Dict, Optional
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
-from compressed_tensors.quantization.utils import is_model_quantized
 from torch import Tensor
 from torch.nn import Module
 
@@ -74,6 +73,7 @@ def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
+        is_marlin: bool = False,
     ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
@@ -82,6 +82,7 @@ def from_pretrained(
         :param state_dict: optional state_dict to replace that in model, used for
             gathering global FSDP model info
         :param compress: whether or not to compress the model on disk
+        :param is_marlin: whether or not marlin compression is being used
         :return: compression config inferred from the model
         """
 
@@ -95,10 +96,11 @@ def from_pretrained(
         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
             model=model
         )
-        if is_model_quantized(model):
-            # compressing a sparse quantized model is not supported yet
+        if is_marlin:
+            # sparse compressor should be dense for marlin
+            # compression
             format = CompressionFormat.dense.value
-        elif compress:
+        if compress:
             format = CompressionFormat.sparse_bitmask.value
         else:
             format = CompressionFormat.dense.value
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index eff21d159..77f6731c6 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -6,7 +6,11 @@
 import torch
 import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import ModelCompressor, SparsityCompressionConfig
+from compressed_tensors import (
+    CompressionFormat,
+    ModelCompressor,
+    SparsityCompressionConfig,
+)
 from loguru import logger
 from safetensors.torch import storage_ptr
 from transformers import PreTrainedModel
@@ -78,15 +82,22 @@ def save_pretrained_wrapper(
         if state_dict is None:
             state_dict = get_state_dict_offloaded_model(model)
 
+        sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)
+        quantization_format = infer_quantization_format(
+            model=model,
+            quantization_format=quantization_format,
+            save_compressed=save_compressed,
+            sparsity_structure=sparsity_structure,
+        )
+        is_marlin = quantization_format == CompressionFormat.marlin_24.value
+
         if sparsity_config is not None:
             sparsity_config.global_sparsity = (
                 SparsityConfigMetadata.infer_global_sparsity(
                     model, state_dict=state_dict
                 )
             )
-            sparsity_config.sparsity_structure = (
-                SparsityConfigMetadata.infer_sparsity_structure()
-            )
+            sparsity_config.sparsity_structure = sparsity_structure
         elif not skip_compression_stats:
             # try to infer a sparsity config from the model if none is provided
             logger.info(
@@ -96,15 +107,12 @@ def save_pretrained_wrapper(
                 "skip_compression_stats=True"
             )
             sparsity_config = SparsityConfigMetadata.from_pretrained(
-                model, state_dict=state_dict, compress=save_compressed
+                model,
+                state_dict=state_dict,
+                compress=save_compressed,
+                is_marlin=is_marlin,
             )
 
-        quantization_format = infer_quantization_format(
-            model=model,
-            quantization_format=quantization_format,
-            save_compressed=save_compressed,
-            sparsity_config=sparsity_config,
-        )
         compressor = ModelCompressor.from_pretrained_model(
             model,
             sparsity_config=sparsity_config,
diff --git a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py
index 7db2f0687..6446e141d 100644
--- a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py
+++ b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py
@@ -30,6 +30,8 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
         module.quantization_scheme = quant_scheme
 
     inferred_format = infer_quantization_format(
-        dummy_model, save_compressed=True, sparsity_config=sparsity_config
+        dummy_model,
+        save_compressed=True,
+        sparsity_structure=sparsity_config.sparsity_structure,
     )
     assert inferred_format.value == expected_format

From 4baf16d14e035a2a7e6cad78113d3499226cbbd1 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Tue, 22 Oct 2024 21:19:00 +0000
Subject: [PATCH 2/5] Enable Sparse24 compressor

Signed-off-by: Rahul Tuli
---
 src/llmcompressor/transformers/compression/sparsity_config.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index 22bfbbb06..bd234aa10 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -101,7 +101,10 @@ def from_pretrained(
             # compression
             format = CompressionFormat.dense.value
         if compress:
-            format = CompressionFormat.sparse_bitmask.value
+            if sparsity_structure == "2:4":
+                format = CompressionFormat.sparse_24.value
+            else:
+                format = CompressionFormat.sparse_bitmask.value
         else:
             format = CompressionFormat.dense.value
 

From 43cb1d7b96d27e4278a25ff265ac88042e1f4414 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Tue, 22 Oct 2024 23:39:33 +0000
Subject: [PATCH 3/5] Add SparsityStructure Enum

Signed-off-by: Rahul Tuli
---
 .../compression/quantization_format.py        |  7 ++-
 .../compression/sparsity_config.py            | 61 ++++++++++++++++++-
 .../compression/test_sparsity_config.py       | 44 +++++++++++++
 3 files changed, 108 insertions(+), 4 deletions(-)
 create mode 100644 tests/llmcompressor/transformers/compression/test_sparsity_config.py

diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index 95e1c0349..d0a03ebf0 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -8,6 +8,8 @@
     iter_named_leaf_modules,
 )
 
+from llmcompressor.transformers.compression.sparsity_config import SparsityStructure
+
 __all__ = ["infer_quantization_format"]
 
 
@@ -35,7 +37,10 @@ def infer_quantization_format(
 
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
-        is_24_structure = sparsity_structure is not None and sparsity_structure == "2:4"
+        is_24_structure = (
+            SparsityStructure(sparsity_structure).value
+            == SparsityStructure.TWO_FOUR.value
+        )
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
         if is_weight_only:  # w4a16 and w8a16
diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index bd234aa10..443fea29c 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,3 +1,4 @@
+from enum import Enum, unique
 from typing import Dict, Optional
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
@@ -13,13 +14,67 @@
 )
 
 
+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityConfigMetadata:
     """
     Class of helper functions for filling out a SparsityCompressionConfig with readable
     metadata from the model
     """
 
-    SPARSITY_THRESHOLD: float = 0.4
+    SPARSITY_THRESHOLD: float = 0.5
 
     @staticmethod
     def infer_global_sparsity(
@@ -66,7 +121,7 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
         if model and sparsity_structure is None:
             sparsity_structure = infer_sparsity_structure_from_model(model)
 
-        return sparsity_structure or "unstructured"
+        return SparsityStructure(sparsity_structure).value
 
     @staticmethod
     def from_pretrained(
@@ -101,7 +156,7 @@ def from_pretrained(
             # compression
             format = CompressionFormat.dense.value
         if compress:
-            if sparsity_structure == "2:4":
+            if sparsity_structure == SparsityStructure.TWO_FOUR.value:
                 format = CompressionFormat.sparse_24.value
             else:
                 format = CompressionFormat.sparse_bitmask.value
diff --git a/tests/llmcompressor/transformers/compression/test_sparsity_config.py b/tests/llmcompressor/transformers/compression/test_sparsity_config.py
new file mode 100644
index 000000000..91f20a361
--- /dev/null
+++ b/tests/llmcompressor/transformers/compression/test_sparsity_config.py
@@ -0,0 +1,44 @@
+import pytest
+
+from llmcompressor.transformers.compression.sparsity_config import SparsityStructure
+
+
+def test_sparsity_structure_valid_cases():
+    assert (
+        SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
+    ), "Failed to match '2:4' with TWO_FOUR"
+    assert (
+        SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'unstructured' with UNSTRUCTURED"
+    assert (
+        SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED"
+    assert (
+        SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match None with UNSTRUCTURED"
+
+
+def test_sparsity_structure_invalid_case():
+    with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"):
+        SparsityStructure("invalid")
+
+
+def test_sparsity_structure_case_insensitivity():
+    assert (
+        SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
+    ), "Failed to match '2:4' with TWO_FOUR"
+    assert (
+        SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR
+    ), "Failed to match '2:4'.upper() with TWO_FOUR"
+    assert (
+        SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'unstructured'.upper() with UNSTRUCTURED"
+    assert (
+        SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED"
+
+
+def test_sparsity_structure_default_case():
+    assert (
+        SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    ), "Failed to match None with UNSTRUCTURED"

From 9fd57a112eac557c799e2950bacef0112aa3383d Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 23 Oct 2024 13:40:19 +0000
Subject: [PATCH 4/5] Add 0:0 Sparsity Structure

---
 src/llmcompressor/transformers/compression/sparsity_config.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index 443fea29c..1dedd285e 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -23,6 +23,8 @@ class SparsityStructure(Enum):
     ----------
     TWO_FOUR : str
         Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
     UNSTRUCTURED : str
         Represents an unstructured sparsity structure.
 
@@ -51,6 +53,7 @@ class SparsityStructure(Enum):
 
     TWO_FOUR = "2:4"
     UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
 
     def __new__(cls, value):
         obj = object.__new__(cls)

From 15bcbc1ecba50b92fab5ffb6d9263b5c4e166efd Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 23 Oct 2024 14:03:56 +0000
Subject: [PATCH 5/5] Move SparsityStructure enum to compressed_tensors

---
 .../compression/quantization_format.py        |  3 +-
 .../compression/sparsity_config.py            | 59 +------------------
 .../compression/test_sparsity_config.py       | 44 --------------
 3 files changed, 2 insertions(+), 104 deletions(-)
 delete mode 100644 tests/llmcompressor/transformers/compression/test_sparsity_config.py

diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index d0a03ebf0..4986cfb7a 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
+from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -8,8 +9,6 @@
     iter_named_leaf_modules,
 )
 
-from llmcompressor.transformers.compression.sparsity_config import SparsityStructure
-
 __all__ = ["infer_quantization_format"]
 
 
diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index 1dedd285e..8943d60f4 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,7 @@
-from enum import Enum, unique
 from typing import Dict, Optional
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.config import SparsityStructure
 from torch import Tensor
 from torch.nn import Module
 
@@ -14,63 +14,6 @@
 )
 
 
-@unique
-class SparsityStructure(Enum):
-    """
-    An enumeration to represent different sparsity structures.
-
-    Attributes
-    ----------
-    TWO_FOUR : str
-        Represents a 2:4 sparsity structure.
-    ZERO_ZERO : str
-        Represents a 0:0 sparsity structure.
-    UNSTRUCTURED : str
-        Represents an unstructured sparsity structure.
-
-    Examples
-    --------
-    >>> SparsityStructure('2:4')
-    <SparsityStructure.TWO_FOUR: '2:4'>
-
-    >>> SparsityStructure('unstructured')
-    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
-
-    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
-    True
-
-    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
-    True
-
-    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
-    True
-
-    >>> SparsityStructure('invalid')
-    Traceback (most recent call last):
-        ...
-    ValueError: invalid is not a valid SparsityStructure
-    """
-
-    TWO_FOUR = "2:4"
-    UNSTRUCTURED = "unstructured"
-    ZERO_ZERO = "0:0"
-
-    def __new__(cls, value):
-        obj = object.__new__(cls)
-        obj._value_ = value.lower() if value is not None else value
-        return obj
-
-    @classmethod
-    def _missing_(cls, value):
-        # Handle None and case-insensitive values
-        if value is None:
-            return cls.UNSTRUCTURED
-        for member in cls:
-            if member.value == value.lower():
-                return member
-        raise ValueError(f"{value} is not a valid {cls.__name__}")
-
-
 class SparsityConfigMetadata:
     """
     Class of helper functions for filling out a SparsityCompressionConfig with readable
diff --git a/tests/llmcompressor/transformers/compression/test_sparsity_config.py b/tests/llmcompressor/transformers/compression/test_sparsity_config.py
deleted file mode 100644
index 91f20a361..000000000
--- a/tests/llmcompressor/transformers/compression/test_sparsity_config.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import pytest
-
-from llmcompressor.transformers.compression.sparsity_config import SparsityStructure
-
-
-def test_sparsity_structure_valid_cases():
-    assert (
-        SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
-    ), "Failed to match '2:4' with TWO_FOUR"
-    assert (
-        SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED
-    ), "Failed to match 'unstructured' with UNSTRUCTURED"
-    assert (
-        SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
-    ), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED"
-    assert (
-        SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
-    ), "Failed to match None with UNSTRUCTURED"
-
-
-def test_sparsity_structure_invalid_case():
-    with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"):
-        SparsityStructure("invalid")
-
-
-def test_sparsity_structure_case_insensitivity():
-    assert (
-        SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
-    ), "Failed to match '2:4' with TWO_FOUR"
-    assert (
-        SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR
-    ), "Failed to match '2:4'.upper() with TWO_FOUR"
-    assert (
-        SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED
-    ), "Failed to match 'unstructured'.upper() with UNSTRUCTURED"
-    assert (
-        SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED
-    ), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED"
-
-
-def test_sparsity_structure_default_case():
-    assert (
-        SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
-    ), "Failed to match None with UNSTRUCTURED"
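
Illustration (not part of the patch series): a minimal sketch of the behavior the
five patches converge on. The sparsity structure string is normalized through the
SparsityStructure enum that [PATCH 5/5] relocates to compressed_tensors.config, and
that normalized value is what the sparse_24 compressor selection and the marlin_24
special case key off. The snippet assumes a compressed-tensors build that already
ships the relocated enum.

    from compressed_tensors.config import SparsityStructure

    # The enum lower-cases its input and maps None to "unstructured", which is what
    # lets infer_quantization_format() and SparsityConfigMetadata.from_pretrained()
    # work from a plain structure string instead of a full SparsityCompressionConfig.
    assert SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
    assert SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
    assert SparsityStructure(None) == SparsityStructure.UNSTRUCTURED

    def is_24_structure(sparsity_structure):
        # The same comparison the reworked infer_quantization_format() performs.
        return (
            SparsityStructure(sparsity_structure).value
            == SparsityStructure.TWO_FOUR.value
        )

    assert is_24_structure("2:4") and not is_24_structure(None)

Passing the plain structure string around, rather than a SparsityCompressionConfig,
is what lets save_pretrained_wrapper() infer the quantization format before the
sparsity config is built, so the marlin_24 check can feed back into the sparsity
compressor choice and the two compressors can be stacked.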