Make compressors stackable
Special condition for marlin_24 compressor
Update tests
rahul-tuli committed Oct 7, 2024
1 parent da4e0dc commit a29da0d
Showing 4 changed files with 30 additions and 21 deletions.
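The core idea of the commit: infer the quantization format first, then build the sparsity config with knowledge of that format, so the two compressors can be stacked (or the sparsity compressor dropped to dense when marlin_24 already packs the 2:4 pattern). Below is a minimal sketch of the resulting save-time flow, assuming a placeholder `model`; the real logic lives in `save_pretrained_wrapper` in the diffs further down, the import paths are assumed from the file layout rather than shown here, and the final `ModelCompressor.from_pretrained_model` keyword past `sparsity_config` is an assumption since that hunk is truncated.

# Sketch only -- not the literal commit code; import paths are assumed.
from accelerate.accelerator import get_state_dict_offloaded_model
from compressed_tensors import CompressionFormat, ModelCompressor
from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)
from llmcompressor.transformers.compression.sparsity_config import SparsityConfigMetadata

def build_compressor(model):  # `model` is any quantized and/or sparse PreTrainedModel
    state_dict = get_state_dict_offloaded_model(model)

    # 1) Sparsity structure is now a plain string ("2:4", "unstructured", ...)
    sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)

    # 2) The quantization format no longer needs a SparsityCompressionConfig up front
    quantization_format = infer_quantization_format(
        model=model,
        save_compressed=True,
        sparsity_structure=sparsity_structure,
    )

    # 3) marlin_24 handles the 2:4 pattern itself, so flag it for the sparsity config
    is_marlin = quantization_format == CompressionFormat.marlin_24.value

    # 4) Only now infer the sparsity config; is_marlin forces the dense format
    sparsity_config = SparsityConfigMetadata.from_pretrained(
        model, state_dict=state_dict, compress=True, is_marlin=is_marlin
    )

    # 5) Both configs feed the stacked ModelCompressor (keyword assumed, see last hunk)
    return ModelCompressor.from_pretrained_model(
        model,
        sparsity_config=sparsity_config,
        quantization_format=quantization_format,
    )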
@@ -1,7 +1,6 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityCompressionConfig
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -16,7 +15,7 @@ def infer_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
+    sparsity_structure: Optional[str] = None,
 ) -> str:
     """
     Infers a quantization format based on model state and compression args
@@ -36,9 +35,7 @@ def infer_quantization_format(
 
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
-        is_24_structure = (
-            sparsity_config and sparsity_config.sparsity_structure == "2:4"
-        )
+        is_24_structure = sparsity_structure is not None and sparsity_structure == "2:4"
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
         if is_weight_only:  # w4a16 and w8a16
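One practical effect of the new signature: a caller that only knows the sparsity pattern no longer has to build a SparsityCompressionConfig just to ask for a format, which is what lets the format be inferred before any sparsity config exists. A hedged example (`quantized_model` is a placeholder name, and `infer_quantization_format` is the function changed in this hunk):

# Assumes a model whose modules already carry a weight-only quantization scheme.
fmt = infer_quantization_format(
    quantized_model,
    save_compressed=True,
    sparsity_structure="2:4",  # previously required a full SparsityCompressionConfig
)
# With a weight-only int scheme and 2:4 structure this should resolve to the
# marlin_24 format; the remaining branches are truncated in the hunk above.

Note that `sparsity_structure is not None and sparsity_structure == "2:4"` is equivalent to the simpler `sparsity_structure == "2:4"`, since None never compares equal to a string.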
10 changes: 6 additions & 4 deletions src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,6 @@
 from typing import Dict, Optional
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
-from compressed_tensors.quantization.utils import is_model_quantized
 from torch import Tensor
 from torch.nn import Module
 
@@ -74,6 +73,7 @@ def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
+        is_marlin: bool = False,
     ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
@@ -82,6 +82,7 @@
         :param state_dict: optional state_dict to replace that in model, used for
             gathering global FSDP model info
         :param compress: whether or not to compress the model on disk
+        :param is_marlin: whether or not marlin compression is being used
         :return: compression config inferred from the model
         """
 
@@ -95,10 +96,11 @@ def from_pretrained(
         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
             model=model
         )
-        if is_model_quantized(model):
-            # compressing a sparse quantized model is not supported yet
+        if is_marlin:
+            # sparse compressor should be dense for marlin
+            # compression
             format = CompressionFormat.dense.value
-        elif compress:
+        if compress:
             format = CompressionFormat.sparse_bitmask.value
         else:
             format = CompressionFormat.dense.value
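The special condition from the commit message lives here: when the marlin_24 quantization format is selected, the 2:4 pattern is already handled by that compressor, so the sparsity config drops to dense rather than stacking a bitmask compressor on top. A small illustration of the intended behaviour, assuming a placeholder `model` with 2:4-sparse weights; the construction of the returned config is truncated above, so the attribute check is an assumption:

# is_marlin=True should pin the sparsity format to dense even though
# compress=True would normally select sparse_bitmask.
config = SparsityConfigMetadata.from_pretrained(
    model, compress=True, is_marlin=True
)
if config is not None:  # the return type is Optional per the signature above
    assert config.format == CompressionFormat.dense.value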
@@ -6,7 +6,11 @@
 import torch
 import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import ModelCompressor, SparsityCompressionConfig
+from compressed_tensors import (
+    CompressionFormat,
+    ModelCompressor,
+    SparsityCompressionConfig,
+)
 from loguru import logger
 from transformers import PreTrainedModel
 
@@ -77,15 +81,22 @@ def save_pretrained_wrapper(
             if state_dict is None:
                 state_dict = get_state_dict_offloaded_model(model)
 
+            sparsity_stucture = SparsityConfigMetadata.infer_sparsity_structure(model)
+            quantization_format = infer_quantization_format(
+                model=model,
+                quantization_format=quantization_format,
+                save_compressed=save_compressed,
+                sparsity_structure=sparsity_stucture,
+            )
+            is_marlin = quantization_format == CompressionFormat.marlin_24.value
+
             if sparsity_config is not None:
                 sparsity_config.global_sparsity = (
                     SparsityConfigMetadata.infer_global_sparsity(
                         model, state_dict=state_dict
                     )
                 )
-                sparsity_config.sparsity_structure = (
-                    SparsityConfigMetadata.infer_sparsity_structure()
-                )
+                sparsity_config.sparsity_structure = sparsity_stucture
             elif not skip_compression_stats:
                 # try to infer a sparsity config from the model if none is provided
                 logger.info(
@@ -95,15 +106,12 @@ def save_pretrained_wrapper(
                     "skip_compression_stats=True"
                 )
                 sparsity_config = SparsityConfigMetadata.from_pretrained(
-                    model, state_dict=state_dict, compress=save_compressed
+                    model,
+                    state_dict=state_dict,
+                    compress=save_compressed,
+                    is_marlin=is_marlin,
                 )
 
-            quantization_format = infer_quantization_format(
-                model=model,
-                quantization_format=quantization_format,
-                save_compressed=save_compressed,
-                sparsity_config=sparsity_config,
-            )
             compressor = ModelCompressor.from_pretrained_model(
                 model,
                 sparsity_config=sparsity_config,
@@ -30,6 +30,8 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
         module.quantization_scheme = quant_scheme
 
     inferred_format = infer_quantization_format(
-        dummy_model, save_compressed=True, sparsity_config=sparsity_config
+        dummy_model,
+        save_compressed=True,
+        sparsity_structure=sparsity_config.sparsity_structure,
     )
     assert inferred_format.value == expected_format
