Enable Sparse compression #822

Closed · wants to merge 5 commits
@@ -1,7 +1,7 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -16,7 +16,7 @@ def infer_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
+    sparsity_structure: Optional[str] = None,
 ) -> str:
     """
     Infers a quantization format based on model state and compression args
@@ -37,7 +37,8 @@ def infer_quantization_format(
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
         is_24_structure = (
-            sparsity_config and sparsity_config.sparsity_structure == "2:4"
+            SparsityStructure(sparsity_structure).value
+            == SparsityStructure.TWO_FOUR.value
         )
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
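
For context on the enum comparison above: SparsityStructure normalizes the raw structure string, so the check no longer hinges on the exact "2:4" literal. A small illustrative snippet follows; the None-to-unstructured fallback is inferred from the return statement this PR replaces in sparsity_config.py below, and is an assumption rather than something verified here.

from compressed_tensors.config import SparsityStructure

# "2:4" resolves to the TWO_FOUR member, so string spellings are normalized
# before comparison rather than compared literally.
assert SparsityStructure("2:4") is SparsityStructure.TWO_FOUR

# A missing structure appears to fall back to "unstructured", matching the
# fallback in the return statement removed below (assumed behavior).
assert SparsityStructure(None).value == "unstructured"
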
src/llmcompressor/transformers/compression/sparsity_config.py (13 additions, 7 deletions)
@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
-from compressed_tensors.quantization.utils import is_model_quantized
+from compressed_tensors.config import SparsityStructure
 from torch import Tensor
 from torch.nn import Module
 
@@ -20,7 +20,7 @@ class SparsityConfigMetadata:
     metadata from the model
     """
 
-    SPARSITY_THRESHOLD: float = 0.4
+    SPARSITY_THRESHOLD: float = 0.5
 
     @staticmethod
     def infer_global_sparsity(
@@ -67,13 +67,14 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
         if model and sparsity_structure is None:
             sparsity_structure = infer_sparsity_structure_from_model(model)
 
-        return sparsity_structure or "unstructured"
+        return SparsityStructure(sparsity_structure).value
 
     @staticmethod
     def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
+        is_marlin: bool = False,

Collaborator:
To make this more generic, why not pass in the quantization config? We will for sure have different compression formats which affect sparsity in the future

Collaborator:
Yeah we should make this generic as marlin will likely not be the only case
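
As a rough sketch of that suggestion (an editorial illustration, not code from this PR): if the inferred quantization format were passed in rather than a marlin-specific flag, the format decision could stay in one place as new formats appear. The helper name, the set constant, and the quantization_format parameter below are assumptions; the CompressionFormat members are the ones referenced elsewhere in this diff.

from typing import Optional

from compressed_tensors import CompressionFormat

# Assumed set of quantization formats that already encode the sparsity
# pattern themselves; marlin_24 is the only such format referenced in this PR.
FORMATS_WITH_BUILTIN_SPARSITY = {CompressionFormat.marlin_24.value}


def infer_sparsity_format(
    sparsity_structure: str,
    compress: bool,
    quantization_format: Optional[str] = None,
) -> str:
    # Hypothetical helper, not part of the PR: mirrors the branching added to
    # from_pretrained below, keyed on the quantization format instead of an
    # is_marlin boolean.
    if quantization_format in FORMATS_WITH_BUILTIN_SPARSITY:
        return CompressionFormat.dense.value
    if compress:
        if sparsity_structure == "2:4":
            return CompressionFormat.sparse_24.value
        return CompressionFormat.sparse_bitmask.value
    return CompressionFormat.dense.value
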

     ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
@@ -82,6 +83,7 @@ def from_pretrained(
         :param state_dict: optional state_dict to replace that in model, used for
             gathering global FSDP model info
         :param compress: whether or not to compress the model on disk
+        :param is_marlin: whether or not marlin compression is being used
         :return: compression config inferred from the model
         """
 
@@ -95,11 +97,15 @@ def from_pretrained(
         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
             model=model
         )
-        if is_model_quantized(model):
-            # compressing a sparse quantized model is not supported yet
+        if is_marlin:
+            # sparse compressor should be dense for marlin
+            # compression
             format = CompressionFormat.dense.value
-        elif compress:
-            format = CompressionFormat.sparse_bitmask.value
+        if compress:
+            if sparsity_structure == SparsityStructure.TWO_FOUR.value:
+                format = CompressionFormat.sparse_24.value
+            else:
+                format = CompressionFormat.sparse_bitmask.value
         else:
             format = CompressionFormat.dense.value
 
@@ -6,7 +6,11 @@
 import torch
 import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import ModelCompressor, SparsityCompressionConfig
+from compressed_tensors import (
+    CompressionFormat,
+    ModelCompressor,
+    SparsityCompressionConfig,
+)
 from loguru import logger
 from safetensors.torch import storage_ptr
 from transformers import PreTrainedModel
@@ -78,15 +82,22 @@ def save_pretrained_wrapper(
         if state_dict is None:
             state_dict = get_state_dict_offloaded_model(model)
 
+        sparsity_stucture = SparsityConfigMetadata.infer_sparsity_structure(model)
+        quantization_format = infer_quantization_format(
+            model=model,
+            quantization_format=quantization_format,
+            save_compressed=save_compressed,
+            sparsity_structure=sparsity_stucture,
+        )
+        is_marlin = quantization_format == CompressionFormat.marlin_24.value
+
         if sparsity_config is not None:
             sparsity_config.global_sparsity = (
                 SparsityConfigMetadata.infer_global_sparsity(
                     model, state_dict=state_dict
                 )
             )
-            sparsity_config.sparsity_structure = (
-                SparsityConfigMetadata.infer_sparsity_structure()
-            )
+            sparsity_config.sparsity_structure = sparsity_stucture
         elif not skip_compression_stats:
             # try to infer a sparsity config from the model if none is provided
             logger.info(
@@ -96,15 +107,12 @@ def save_pretrained_wrapper(
"skip_compression_stats=True"
)
sparsity_config = SparsityConfigMetadata.from_pretrained(
model, state_dict=state_dict, compress=save_compressed
model,
state_dict=state_dict,
compress=save_compressed,
is_marlin=is_marlin,
)

quantization_format = infer_quantization_format(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
sparsity_config=sparsity_config,
)
compressor = ModelCompressor.from_pretrained_model(
model,
sparsity_config=sparsity_config,
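
Taken together, the changes in this file reorder the save path so that the quantization format is inferred before the sparsity config is built. Below is a condensed, illustrative sketch assembled from the added lines above; the quantization_format import path is assumed from the repository layout and is not shown in this diff.

from compressed_tensors import CompressionFormat

# Import paths are assumptions; only sparsity_config.py is named in this diff.
from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)
from llmcompressor.transformers.compression.sparsity_config import (
    SparsityConfigMetadata,
)


def infer_configs_for_save(model, state_dict, save_compressed, quantization_format=None):
    # Illustrative helper condensing the new ordering in save_pretrained_wrapper.
    # 1. Infer the sparsity structure once; it now feeds format inference.
    sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)

    # 2. Infer the quantization format before the sparsity config so the
    #    sparsity compressor can react to it.
    quantization_format = infer_quantization_format(
        model=model,
        quantization_format=quantization_format,
        save_compressed=save_compressed,
        sparsity_structure=sparsity_structure,
    )

    # 3. marlin_24 already encodes the 2:4 pattern, so sparsity compression
    #    stays dense in that case instead of compressing twice.
    is_marlin = quantization_format == CompressionFormat.marlin_24.value
    sparsity_config = SparsityConfigMetadata.from_pretrained(
        model,
        state_dict=state_dict,
        compress=save_compressed,
        is_marlin=is_marlin,
    )
    return quantization_format, sparsity_config
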
@@ -30,6 +30,8 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
         module.quantization_scheme = quant_scheme
 
     inferred_format = infer_quantization_format(
-        dummy_model, save_compressed=True, sparsity_config=sparsity_config
+        dummy_model,
+        save_compressed=True,
+        sparsity_structure=sparsity_config.sparsity_structure,
     )
     assert inferred_format.value == expected_format