From 129770e1a779225fa643418b0be52469e66fe68a Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Mon, 15 Apr 2024 11:13:19 -0400 Subject: [PATCH] Apply quantization config implementation (#4) * Apply quantization config implementation * add TODO * integrate full lifecycle support, QuantizationStatus updates, add tinyllama test * fix comment --- src/sparsetensors/quantization/__init__.py | 2 + .../quantization/lifecycle/__init__.py | 3 +- .../quantization/lifecycle/apply.py | 120 ++++++++++++++++++ .../quantization/lifecycle/calibration.py | 2 +- .../quantization/lifecycle/forward.py | 2 +- .../quantization/lifecycle/frozen.py | 7 +- .../quantization/lifecycle/initialize.py | 2 +- .../quantization/quant_config.py | 26 +++- .../lifecycle/status.py => tests/__init__.py | 13 -- tests/quantization/__init__.py | 13 ++ tests/quantization/lifecycle/__init__.py | 13 ++ tests/quantization/lifecycle/test_apply.py | 113 +++++++++++++++++ 12 files changed, 296 insertions(+), 20 deletions(-) create mode 100644 src/sparsetensors/quantization/lifecycle/apply.py rename src/sparsetensors/quantization/lifecycle/status.py => tests/__init__.py (76%) create mode 100644 tests/quantization/__init__.py create mode 100644 tests/quantization/lifecycle/__init__.py create mode 100644 tests/quantization/lifecycle/test_apply.py diff --git a/src/sparsetensors/quantization/__init__.py b/src/sparsetensors/quantization/__init__.py index b53a328a..7227f889 100644 --- a/src/sparsetensors/quantization/__init__.py +++ b/src/sparsetensors/quantization/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. # flake8: noqa +# isort: skip_file + from .quant_args import * from .quant_config import * from .quant_scheme import * diff --git a/src/sparsetensors/quantization/lifecycle/__init__.py b/src/sparsetensors/quantization/lifecycle/__init__.py index 52b86440..9504597b 100644 --- a/src/sparsetensors/quantization/lifecycle/__init__.py +++ b/src/sparsetensors/quantization/lifecycle/__init__.py @@ -13,9 +13,10 @@ # limitations under the License. # flake8: noqa +# isort: skip_file from .calibration import * from .forward import * from .frozen import * from .initialize import * -from .status import * +from .apply import * diff --git a/src/sparsetensors/quantization/lifecycle/apply.py b/src/sparsetensors/quantization/lifecycle/apply.py new file mode 100644 index 00000000..f58ee636 --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/apply.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
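+
+# Applies a QuantizationConfig to a loaded torch model: every leaf module is
+# matched against the config's target names / "re:" regex patterns (unless it
+# matches the ignore list), and each matched module is advanced through the
+# quantization lifecycle up to the status recorded in the config.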
+ +import re +from collections import OrderedDict +from typing import Iterable, Optional, Tuple + +from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration +from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization +from sparsetensors.quantization.lifecycle.initialize import ( + initialize_module_for_quantization, +) +from sparsetensors.quantization.quant_config import ( + QuantizationConfig, + QuantizationStatus, +) +from sparsetensors.quantization.quant_scheme import QuantizationScheme +from torch.nn import Module + + +__all__ = [ + "apply_quantization_config", + "apply_quantization_status", +] + + +def apply_quantization_config(model: Module, config: QuantizationConfig): + """ + Initializes the model for quantization in-place based on the given config + + :param model: model to apply quantization config to + :param config: quantization config + """ + # build mapping of targets to schemes for easier matching + # use ordered dict to preserve target ordering in config + target_to_scheme = OrderedDict() + for scheme in config.config_groups.values(): + for target in scheme.targets: + target_to_scheme[target] = scheme + + # build list of layers to target to avoid mutating submodule dict during iteration + layer_quant_scheme_pairs = [] + for name, submodule in _iter_named_leaf_modules(model): + if _find_first_name_or_class_match(name, submodule, config.ignore): + continue # layer matches ignore list, continue + target = _find_first_name_or_class_match(name, submodule, target_to_scheme) + if target is not None: + # target matched - add layer and scheme to target list + layer_quant_scheme_pairs.append((submodule, target_to_scheme[target])) + + # apply current quantization status for each matched pair + for layer, scheme in layer_quant_scheme_pairs: + apply_quantization_status( + module=layer, + scheme=scheme, + status=config.quantization_status, + ) + + +def apply_quantization_status( + module: Module, scheme: QuantizationScheme, status: QuantizationStatus +): + """ + Applies in place the quantization lifecycle up to the given status + + :param module: module to apply quantization to + :param scheme: quantization scheme to apply + :param status: status to update the module to + """ + if status >= QuantizationStatus.INITIALIZED: + initialize_module_for_quantization(module, scheme) + if status >= QuantizationStatus.CALIBRATION: + set_module_for_calibration(module) + if status >= QuantizationStatus.FROZEN: + freeze_module_quantization(module) + + +def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]: + # yields modules that do not have any submodules + # TODO: potentially expand to add list of allowed submodules such as observers + for name, submodule in model.named_modules(): + if len(list(submodule.children())) == 0: + yield name, submodule + + +def _find_first_name_or_class_match( + name: str, + module: Module, + targets: Iterable[str], +) -> Optional[str]: + # first element of targets that matches the given name + # if no name matches returns first target that matches the class name + # returns None otherwise + return _find_first_match(name, targets) or _find_first_match( + module.__class__.__name__, targets + ) + + +def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]: + # returns first element of target that matches value either + # exactly or as a regex after 're:' + for target in targets: + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return target + elif target 
== value: + return target + return None diff --git a/src/sparsetensors/quantization/lifecycle/calibration.py b/src/sparsetensors/quantization/lifecycle/calibration.py index 986b062a..37102b6c 100644 --- a/src/sparsetensors/quantization/lifecycle/calibration.py +++ b/src/sparsetensors/quantization/lifecycle/calibration.py @@ -15,7 +15,7 @@ import logging -from sparsetensors.quantization.lifecycle.status import QuantizationStatus +from sparsetensors.quantization.quant_config import QuantizationStatus from torch.nn import Module diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py index ab20e29b..3624229a 100644 --- a/src/sparsetensors/quantization/lifecycle/forward.py +++ b/src/sparsetensors/quantization/lifecycle/forward.py @@ -15,8 +15,8 @@ from functools import wraps import torch -from sparsetensors.quantization.lifecycle.status import QuantizationStatus from sparsetensors.quantization.quant_args import QuantizationArgs +from sparsetensors.quantization.quant_config import QuantizationStatus from sparsetensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module diff --git a/src/sparsetensors/quantization/lifecycle/frozen.py b/src/sparsetensors/quantization/lifecycle/frozen.py index 6b92eee7..1cf6fd1f 100644 --- a/src/sparsetensors/quantization/lifecycle/frozen.py +++ b/src/sparsetensors/quantization/lifecycle/frozen.py @@ -13,7 +13,7 @@ # limitations under the License. -from sparsetensors.quantization.lifecycle.status import QuantizationStatus +from sparsetensors.quantization.quant_config import QuantizationStatus from torch.nn import Module @@ -28,9 +28,12 @@ def freeze_module_quantization(module: Module): return # delete observers from module + observer_names = [] for submodule_name, _ in module.named_modules(): if "." 
not in submodule_name and submodule_name.endswith("_observer"):
             # delete any observers that belong directly to this module
-            delattr(module, submodule_name)
+            observer_names.append(submodule_name)
+
+    for observer_name in observer_names:
+        delattr(module, observer_name)
 
     module.quantization_status = QuantizationStatus.FROZEN
diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py
index 5661fdbd..a87dbc3d 100644
--- a/src/sparsetensors/quantization/lifecycle/initialize.py
+++ b/src/sparsetensors/quantization/lifecycle/initialize.py
@@ -17,8 +17,8 @@
 
 import torch
 from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
 from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_config import QuantizationStatus
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
 
diff --git a/src/sparsetensors/quantization/quant_config.py b/src/sparsetensors/quantization/quant_config.py
index 1c8bd796..813c7197 100644
--- a/src/sparsetensors/quantization/quant_config.py
+++ b/src/sparsetensors/quantization/quant_config.py
@@ -19,7 +19,11 @@
 
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
 
-__all__ = ["QuantizationStatus", "QuantizationConfig"]
+__all__ = [
+    "QuantizationStatus",
+    "QuantizationConfig",
+    "LIFECYCLE_ORDER",
+]
 
 
 class QuantizationStatus(Enum):
@@ -41,6 +45,26 @@ class QuantizationStatus(Enum):
     FROZEN = "frozen"
     COMPRESSED = "compressed"
 
+    @classmethod
+    def lifecycle_order(cls) -> List["QuantizationStatus"]:
+        """
+        :return: list of the quantization statuses in correct lifecycle order
+        """
+        return LIFECYCLE_ORDER
+
+    def __ge__(self, other):
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
+
+
+LIFECYCLE_ORDER = [
+    QuantizationStatus.INITIALIZED,
+    QuantizationStatus.CALIBRATION,
+    QuantizationStatus.FROZEN,
+    QuantizationStatus.COMPRESSED,
+]
+
 
 class QuantizationConfig(BaseModel):
     """
diff --git a/src/sparsetensors/quantization/lifecycle/status.py b/tests/__init__.py
similarity index 76%
rename from src/sparsetensors/quantization/lifecycle/status.py
rename to tests/__init__.py
index 3b6a441d..0c44f887 100644
--- a/src/sparsetensors/quantization/lifecycle/status.py
+++ b/tests/__init__.py
@@ -11,16 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from enum import Enum
-
-
-__all__ = [
-    "QuantizationStatus",
-]
-
-
-class QuantizationStatus(Enum):
-    INITIALIZED = "INITIALIZED"
-    CALIBRATION = "CALIBRATION"
-    FROZEN = "FROZEN"
diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py
new file mode 100644
index 00000000..0c44f887
--- /dev/null
+++ b/tests/quantization/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/quantization/lifecycle/__init__.py b/tests/quantization/lifecycle/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/tests/quantization/lifecycle/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/quantization/lifecycle/test_apply.py b/tests/quantization/lifecycle/test_apply.py new file mode 100644 index 00000000..46351cd8 --- /dev/null +++ b/tests/quantization/lifecycle/test_apply.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
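+
+# Integration test: applies a two-group config (symmetric int8 weights and input
+# activations on Linear, asymmetric int8 weights on Embedding, rotary embeddings
+# ignored) to TinyLlama and checks that quantization scale / zero_point
+# attributes are attached only to the targeted modules.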
+
+
+from sparsetensors.quantization.lifecycle import apply_quantization_config
+from sparsetensors.quantization.quant_config import QuantizationConfig
+from transformers import AutoModelForCausalLM
+
+
+def test_apply_quantization_config_tinyllama():
+    quant_config = get_sample_tinyllama_quant_config()
+    model = get_tinyllama_model()
+
+    # check that model is not already quantized
+    for module in model.modules():
+        _test_layer_quantization_status(module, inputs=False, weights=False)
+
+    # apply quant config to model
+    apply_quantization_config(model, quant_config)
+
+    # check for correct application of quant config
+    num_linears = 0
+    num_embeddings = 0
+    num_rotary_embeddings = 0
+    for module in model.modules():
+        module_type = module.__class__.__name__
+        if module_type == "Linear":
+            num_linears += 1
+            _test_layer_quantization_status(module, inputs=True, weights=True)
+        elif module_type == "Embedding":
+            num_embeddings += 1
+            _test_layer_quantization_status(module, inputs=False, weights=True)
+        elif module_type == "LlamaRotaryEmbedding":
+            num_rotary_embeddings += 1
+            _test_layer_quantization_status(module, inputs=False, weights=False)
+
+    # sanity check correct number of layers targeted
+    assert num_linears == 155
+    assert num_embeddings == 1
+    assert num_rotary_embeddings == 22
+
+
+def _test_layer_quantization_status(module, inputs: bool, weights: bool):
+    # check if quantization is applied at all (true if inputs or weights targeted)
+    quantized = inputs or weights
+    assert hasattr(module, "quantization_scheme") == quantized
+    assert hasattr(module, "quantization_status") == quantized
+
+    # check that input quantization attributes match expectations
+    assert hasattr(module, "input_scale") == inputs
+    assert hasattr(module, "input_zero_point") == inputs
+
+    # check that weight quantization attributes match expectations
+    assert hasattr(module, "weight_scale") == weights
+    assert hasattr(module, "weight_zero_point") == weights
+
+
+def get_tinyllama_model():
+    return AutoModelForCausalLM.from_pretrained(
+        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+    )
+
+
+def get_sample_tinyllama_quant_config():
+    config_dict = {
+        "quant_method": "sparseml",
+        "format": "fakequant",
+        "quantization_status": "frozen",
+        "global_compression_ratio": None,
+        "config_groups": {
+            "group_1": {
+                "weights": {
+                    "num_bits": 8,
+                    "type": "int",
+                    "symmetric": True,
+                    "strategy": "tensor",
+                },
+                "input_activations": {
+                    "num_bits": 8,
+                    "type": "int",
+                    "symmetric": True,
+                    "strategy": "tensor",
+                },
+                "targets": ["Linear"],
+            },
+            "group_2": {
+                "weights": {
+                    "num_bits": 8,
+                    "type": "int",
+                    "symmetric": False,
+                    "strategy": "tensor",
+                },
+                "input_activations": None,
+                "targets": ["Embedding"],
+            },
+        },
+        "ignore": ["LlamaRotaryEmbedding"],
+    }
+    return QuantizationConfig.parse_obj(config_dict)
+
+
+if __name__ == "__main__":
+    test_apply_quantization_config_tinyllama()
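
For context, a minimal end-to-end sketch of the new entrypoint outside of pytest; the checkpoint name, the "re:" target pattern, and the ignore entry below are illustrative choices, not values fixed by this patch:

from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM

# any torch model works; this checkpoint is just an example
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
)

# targets may be exact class names ("Linear"), exact submodule names, or regex
# patterns prefixed with "re:"; matched modules are advanced through the
# lifecycle up to the status recorded in the config ("frozen" here)
config = QuantizationConfig.parse_obj(
    {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["re:.*q_proj"],
            },
        },
        "ignore": ["lm_head"],
    }
)
apply_quantization_config(model, config)

After the call, each matched q_proj module carries quantization_scheme, weight_scale, and weight_zero_point attributes, mirroring what test_apply_quantization_config_tinyllama asserts for its targets.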