[Transforms] Apply, serialize, deserialize #276

Open
wants to merge 3 commits into base: transform_arg_support
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
@@ -18,3 +18,4 @@
KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
COMPRESSION_VERSION_NAME = "version"
QUANTIZATION_METHOD_NAME = "quant_method"
TRANSFORMS_CONFIG = "transforms_config"
@@ -29,6 +29,7 @@
QUANTIZATION_CONFIG_NAME,
QUANTIZATION_METHOD_NAME,
SPARSITY_CONFIG_NAME,
TRANSFORMS_CONFIG,
)
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
@@ -45,6 +46,7 @@
is_module_quantized,
iter_named_leaf_modules,
)
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.utils import (
get_safetensors_folder,
merge_names,
@@ -133,6 +135,8 @@ def from_compression_config(

sparsity_config = cls.parse_sparsity_config(compression_config)
quantization_config = cls.parse_quantization_config(compression_config)
transforms_config = cls.parse_transforms_config(compression_config)

if sparsity_config is None and quantization_config is None:
return None

@@ -144,8 +148,13 @@
if quantization_config is not None:
quantization_config = QuantizationConfig.model_validate(quantization_config)

if transforms_config is not None:
transforms_config = TransformationConfig.model_validate(transforms_config)

return cls(
-    sparsity_config=sparsity_config, quantization_config=quantization_config
+    sparsity_config=sparsity_config,
+    quantization_config=quantization_config,
+    transforms_config=transforms_config,
)

@classmethod
@@ -170,6 +179,10 @@ def from_pretrained_model(
model, format=quantization_format
)

# TODO: update to fetch from the pretrained model
# using the attached config for now
transforms_config = getattr(model, "transforms_config", None)

if isinstance(sparsity_config, str): # we passed in a sparsity format
sparsity_config = SparsityCompressionConfig.load_from_registry(
sparsity_config
@@ -179,9 +192,25 @@
return None

return cls(
-    sparsity_config=sparsity_config, quantization_config=quantization_config
+    sparsity_config=sparsity_config,
+    quantization_config=quantization_config,
+    transforms_config=transforms_config,
)
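As the TODO above notes, the transforms config is currently read off the model object itself. A minimal usage sketch, assuming this classmethod lives on ModelCompressor and that a TransformationConfig has already been built elsewhere; the import path and variable names are illustrative, not part of the diff:

```python
# Illustrative only: the owning class and import path are assumptions.
from compressed_tensors.compressors import ModelCompressor

# `model` is a torch.nn.Module and `transformation_config` a prebuilt
# TransformationConfig; for now the config is read back via
# getattr(model, "transforms_config", None) as in the TODO above.
model.transforms_config = transformation_config
compressor = ModelCompressor.from_pretrained_model(model)
if compressor is not None:  # None when neither sparsity nor quantization is configured
    assert compressor.transforms_config is transformation_config
```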

@staticmethod
def parse_transforms_config(
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
) -> Union[Dict[str, Any], None]:

if compression_config is None:
return None

if is_compressed_tensors_config(compression_config):
t_config = compression_config.transforms_config
return t_config.model_dump() if t_config is not None else None

return compression_config.get(TRANSFORMS_CONFIG, None)
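A small sketch of the two call shapes this helper accepts, either the raw dict loaded from config.json or an HF-style CompressedTensorsConfig object; the payload below is a placeholder and ModelCompressor as the owning class is an assumption:

```python
raw_config = {
    "quant_method": "compressed-tensors",
    "transforms_config": {"transform_groups": {}},  # placeholder payload
}

# dict input: the "transforms_config" entry is returned as-is
assert ModelCompressor.parse_transforms_config(raw_config) == {"transform_groups": {}}

# missing config: None falls straight through
assert ModelCompressor.parse_transforms_config(None) is None
```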

@staticmethod
def parse_sparsity_config(
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
@@ -243,9 +272,11 @@ def __init__(
self,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_config: Optional[QuantizationConfig] = None,
transforms_config: Optional[TransformationConfig] = None,
):
self.sparsity_config = sparsity_config
self.quantization_config = quantization_config
self.transforms_config = transforms_config
self.sparsity_compressor = None
self.quantization_compressor = None

@@ -434,7 +465,9 @@ def decompress(self, model_path: str, model: Module):
self.quantization_config, QuantizationStatus.FROZEN
):
names_to_scheme = apply_quantization_config(
-    model, self.quantization_config
+    model,
+    self.quantization_config,
+    transforms_config=self.transforms_config,
)
load_pretrained_quantization(model, model_path)

@@ -497,6 +530,12 @@ def update_config(self, save_directory: str):
SPARSITY_CONFIG_NAME
] = sparsity_config_data

if self.transforms_config is not None:
transforms_config_data = self.transforms_config.to_dict()
config_data[QUANTIZATION_CONFIG_NAME][
TRANSFORMS_CONFIG
] = transforms_config_data

with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)
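For reference, a hypothetical shape of the resulting config.json after update_config runs; the surrounding keys and values are placeholders, and QUANTIZATION_CONFIG_NAME is assumed to resolve to "quantization_config":

```python
config_data = {
    "quantization_config": {
        "quant_method": "compressed-tensors",  # QUANTIZATION_METHOD_NAME
        "version": "0.x",                      # COMPRESSION_VERSION_NAME, placeholder
        "config_groups": {},                   # existing quantization payload, placeholder
        "transforms_config": {                 # TRANSFORMS_CONFIG = self.transforms_config.to_dict()
            "transform_groups": {},            # placeholder
        },
    },
}
```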

@@ -126,6 +126,7 @@ def decompress_weight(
:param quantization_args: quantization parameters for the weight
:return: tensor of the decompressed weight
"""

weight = compressed_data["weight_packed"]
scale = compressed_data["weight_scale"]
zero_point = compressed_data.get("weight_zero_point", None)
129 changes: 127 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -41,6 +41,9 @@
iter_named_leaf_modules,
iter_named_quantizable_modules,
)
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.transforms.transform_data import TransformData
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
from compressed_tensors.utils.offload import update_parameter_data
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
@@ -49,20 +52,50 @@

__all__ = [
"load_pretrained_quantization",
"load_transforms",
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
"expand_target_names",
"is_target",
"process_transforms_config",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
from compressed_tensors.utils.safetensors_load import (
get_quantization_state_dict,
get_weight_mappings,
)
from safetensors import safe_open


_LOGGER = logging.getLogger(__name__)


def load_transforms(model: Module, model_name_or_path: str):
    model_path = get_safetensors_folder(model_name_or_path)
    weight_mappings = get_weight_mappings(model_path)

    state_dict = {}
    for weight_name, safe_path in weight_mappings.items():
        if "transform" in weight_name:
            with safe_open(safe_path, framework="pt", device="cpu") as f:
                state_dict[weight_name] = f.get_tensor(weight_name)

    for name, submodule in iter_named_leaf_modules(model):
        transform_data = getattr(submodule, "transform_data", None)
@kylesayrs (Contributor), Apr 28, 2025:
TransformData includes a dictionary of all the transforms-relevant runtime data, and is attached to the layer as "transform_data"

To me, it seems like this information is essentially duplicating the information of the scheme/args? Why not attach the scheme/args to the module, rather than creating a new abstraction?
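
For readers following the thread, a rough stand-in for the structure being discussed, inferred only from how transform_data is used in this diff (an OrderedDict of per-transform entries plus an index counter); the real TransformData class may differ:

```python
from collections import OrderedDict
from dataclasses import dataclass, field


@dataclass
class TransformDataSketch:
    # maps transform_name -> {"transform": <Transforms instance>, "call_args": ...}
    data: OrderedDict = field(default_factory=OrderedDict)
    idx: int = 0
```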


        if transform_data:
            for transform_name, transform_values in transform_data.data.items():
                full_name = f"{name}.{transform_name}"
                transform_data = state_dict.get(full_name, None)
                transform = transform_values.get("transform")
                transform.register_to_module(name=transform_name, module=submodule)
                transform.update_transform(
                    module=submodule, data=transform_data, name=transform_name
                )


def load_pretrained_quantization(model: Module, model_name_or_path: str):
"""
Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -104,8 +137,94 @@
)


@brian-dellabetta (Contributor), Apr 28, 2025, on process_transforms_config:
I missed that this is actually modifying model, registering transforms to it. The name process_transforms_config is pretty innocuous for something like this, given it's modifying model significantly. Consider renaming to add_transforms_to_model or something more explicit?

Review comment (Contributor), on the transforms_config parameter:
Another example where I think consistency in naming convention will help users. Something like this?
Suggested change:
-    transforms_config: TransformationConfig,
+    transform_config: TransformConfig,

def process_transforms_config(
    transforms_config: TransformationConfig,
    model: torch.nn.Module,
    quantization_status: Optional[QuantizationStatus] = QuantizationStatus.INITIALIZED,
):
    for _, group in transforms_config.transform_groups.items():
        # Each group/scheme targets one type of transform
        transform_type = group.transform_type
        transform_creation_args = group.transform_creation_args

        # Need a better name - too many groups
        for transform_arg in group.groups:
            module_targets = transform_arg.module_targets

            for name, submodule in model.named_modules():
                if len(transform_arg.ignore) > 0:
                    if matches := find_name_or_class_matches(
                        name, submodule, transform_arg.ignore
                    ):
                        for match in matches:
                            print("ignoring", match, name)
                        continue  # layer matches ignore list, continue

                targets = find_name_or_class_matches(
                    name, submodule, transform_arg.targets
                )

                if targets:
                    # Every layer which matches gets its own transform
                    # Same transform type and args are used however

                    # attach the transform to the submodule
                    # because we can have more than one transform, need to attach some
                    # form of key to fetch
                    # OR we store it in the dictionary, handle cpu-offloading separately

                    if hasattr(submodule, "transform_data"):
                        idx = submodule.transform_data.idx + 1
                    else:
                        idx = 0
                    # only support weight parameters for now, assume one value in
                    # module targets
                    transform_name = f"{module_targets[0]}_transform_{idx}"

                    # create an empty tensor OR create a new transform
                    dtype = getattr(submodule, module_targets[0]).dtype
                    if quantization_status in [
                        QuantizationStatus.COMPRESSED,
                        QuantizationStatus.FROZEN,
                    ]:
                        transform = Transforms.load_from_registry(
                            transform_type,
                            dtype=dtype,
                            empty=True,
                            **transform_creation_args,
                        )
                    else:
                        transform = Transforms.load_from_registry(
                            transform_type,
                            dtype=dtype,
                            **transform_creation_args,
                        )
                    transform.register_to_module(
                        name=transform_name, module=submodule
                    )

                    # add relevant transform data to the submodule as well
                    data = {
                        transform_name: {
                            "transform": transform,
                            "call_args": transform_arg.call_args,
                        }
                    }

                    if hasattr(submodule, "transform_data"):
                        submodule.transform_data.data.update(data)
                        submodule.transform_data.idx = idx
                    else:
                        transform_data = TransformData(data=OrderedDict(data))
                        submodule.transform_data = transform_data
    return model
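
To make the side effects concrete, here is an illustrative look at the state this function leaves on a matched submodule; the layer path is hypothetical, and the transform name follows the f"{module_targets[0]}_transform_{idx}" pattern above with module_targets == ["weight"]:

```python
model = process_transforms_config(transforms_config, model)

layer = model.model.layers[0].self_attn.q_proj      # hypothetical matched Linear
assert hasattr(layer, "transform_data")
assert list(layer.transform_data.data.keys()) == ["weight_transform_0"]

entry = layer.transform_data.data["weight_transform_0"]
entry["transform"]   # the Transforms object also registered onto the module
entry["call_args"]   # call args carried over from the transform group
```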


def apply_quantization_config(
-    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+    model: Module,
+    config: Union[QuantizationConfig, None],
+    run_compressed: bool = False,
+    transforms_config=None,
) -> OrderedDict:
"""
Initializes the model for quantization in-place based on the given config.
@@ -184,6 +303,12 @@
f"{set(config.ignore) - set(ignored_submodules)}"
)

    if transforms_config:
        model.transforms_config = transforms_config
        model = process_transforms_config(
            transforms_config, model, config.quantization_status
        )

    # apply current quantization status across all targeted layers
    apply_quantization_status(model, config.quantization_status)
    return names_to_scheme
19 changes: 19 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -28,6 +28,10 @@
calculate_range,
compute_dynamic_scales_and_zp,
)
from compressed_tensors.transforms.apply import (
apply_inverse_transforms_to_parameter,
apply_transforms_to_parameter,
)
from compressed_tensors.utils import safe_permute
from torch.nn import Module

@@ -280,10 +284,25 @@ def wrapped_forward(self, *args, **kwargs):
    if scheme.weights is not None and not compressed:
        # calibrate and (fake) quantize weights when applicable
        unquantized_weight = self.weight.data.clone()
        transform_data = getattr(module, "transform_data", None)
        if transform_data is not None:
            apply_transforms_to_parameter(
                module=module,
                module_parameter=self.weight,
                transform_data=transform_data,
            )

        self.weight.data = forward_quantize(
            module, self.weight, "weight", scheme.weights
        )

        if transform_data is not None:
            apply_inverse_transforms_to_parameter(
                module=module,
                module_parameter=self.weight,
                transform_data=transform_data,
            )

    # perform wrapped forward call
    output = forward_func_orig.__get__(module, module.__class__)(
        input_, *args[1:], **kwargs