Apply quantization config implementation (#4)
* Apply quantization config implementation
* add TODO
* integrate full lifecycle support, QuantizationStatus updates, add tinyllama test
* fix comment
Showing 12 changed files with 296 additions and 20 deletions.
@@ -0,0 +1,120 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections import OrderedDict
from typing import Iterable, Optional, Tuple

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_config import (
    QuantizationConfig,
    QuantizationStatus,
)
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module


__all__ = [
    "apply_quantization_config",
    "apply_quantization_status",
]


def apply_quantization_config(model: Module, config: QuantizationConfig):
    """
    Initializes the model for quantization in-place based on the given config

    :param model: model to apply quantization config to
    :param config: quantization config
    """
    # build mapping of targets to schemes for easier matching
    # use ordered dict to preserve target ordering in config
    target_to_scheme = OrderedDict()
    for scheme in config.config_groups.values():
        for target in scheme.targets:
            target_to_scheme[target] = scheme

    # build list of layers to target to avoid mutating submodule dict during iteration
    layer_quant_scheme_pairs = []
    for name, submodule in _iter_named_leaf_modules(model):
        if _find_first_name_or_class_match(name, submodule, config.ignore):
            continue  # layer matches ignore list, continue
        target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
        if target is not None:
            # target matched - add layer and scheme to target list
            layer_quant_scheme_pairs.append((submodule, target_to_scheme[target]))

    # apply current quantization status for each matched pair
    for layer, scheme in layer_quant_scheme_pairs:
        apply_quantization_status(
            module=layer,
            scheme=scheme,
            status=config.quantization_status,
        )


def apply_quantization_status(
    module: Module, scheme: QuantizationScheme, status: QuantizationStatus
):
    """
    Applies in place the quantization lifecycle up to the given status

    :param module: module to apply quantization to
    :param scheme: quantization scheme to apply
    :param status: status to update the module to
    """
    if status >= QuantizationStatus.INITIALIZED:
        initialize_module_for_quantization(module, scheme)
    if status >= QuantizationStatus.CALIBRATION:
        set_module_for_calibration(module)
    if status >= QuantizationStatus.FROZEN:
        freeze_module_quantization(module)
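

# NOTE: the ">=" comparisons above assume QuantizationStatus is an ordered
# enum (INITIALIZED < CALIBRATION < FROZEN), so requesting FROZEN walks the
# module through all three lifecycle steps in order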


def _iter_named_leaf_modules(model: Module) -> Iterable[Tuple[str, Module]]:
    # yields modules that do not have any submodules
    # TODO: potentially expand to add list of allowed submodules such as observers
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:
            yield name, submodule


def _find_first_name_or_class_match(
    name: str,
    module: Module,
    targets: Iterable[str],
) -> Optional[str]:
    # returns the first element of targets that matches the given name;
    # if no name matches, returns the first target that matches the class name;
    # returns None otherwise
    return _find_first_match(name, targets) or _find_first_match(
        module.__class__.__name__, targets
    )


def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
    # returns the first element of targets that matches value either
    # exactly or as a regex after 're:'
    for target in targets:
        if target.startswith("re:"):
            pattern = target[3:]
            if re.match(pattern, value):
                return target
        elif target == value:
            return target
    return None
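
Below is a minimal usage sketch (not part of the commit): it applies a config, mirroring the sample in the test further down, to a toy torch.nn.Sequential model. "Linear" is matched by class name, while the hypothetical "re:.*2" ignore entry exercises the regex branch of _find_first_match.

import torch

from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig

# toy model; leaf submodule names are "0", "1", "2"
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

config = QuantizationConfig.parse_obj(
    {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["Linear"],  # matched against the class name
            },
        },
        "ignore": ["re:.*2"],  # regex target: skips the module named "2"
    }
)

apply_quantization_config(model, config)

# the first Linear was matched and carried through the lifecycle;
# the second was skipped by the ignore regex
print(hasattr(model[0], "quantization_scheme"))  # True
print(hasattr(model[2], "quantization_scheme"))  # False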
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,113 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM


def test_apply_quantization_config_tinyllama():
    quant_config = get_sample_tinyllama_quant_config()
    model = get_tinyllama_model()

    # check that model is not already quantized
    for module in model.modules():
        _test_layer_quantization_status(module, inputs=False, weights=False)

    # apply quant config to model
    apply_quantization_config(model, quant_config)

    # check for correct application of quant config
    num_linears = 0
    num_embeddings = 0
    num_rotary_embeddings = 0
    for module in model.modules():
        module_type = module.__class__.__name__
        if module_type == "Linear":
            num_linears += 1
            _test_layer_quantization_status(module, inputs=True, weights=True)
        elif module_type == "Embedding":
            num_embeddings += 1
            _test_layer_quantization_status(module, inputs=False, weights=True)
        elif module_type == "LlamaRotaryEmbedding":
            num_rotary_embeddings += 1
            _test_layer_quantization_status(module, inputs=False, weights=False)

    # sanity check correct number of layers targeted
    assert num_linears == 155
    assert num_embeddings == 1
    assert num_rotary_embeddings == 22


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
    # check if quantization is applied at all (true if inputs or weights targeted)
    quantized = inputs or weights
    assert hasattr(module, "quantization_scheme") == quantized
    assert hasattr(module, "quantization_status") == quantized

    # check input quantization parameters match expected
    assert hasattr(module, "input_scale") == inputs
    assert hasattr(module, "input_zero_point") == inputs

    # check weight quantization parameters match expected
    assert hasattr(module, "weight_scale") == weights
    assert hasattr(module, "weight_zero_point") == weights


def get_tinyllama_model():
    return AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    )


def get_sample_tinyllama_quant_config():
    config_dict = {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "targets": ["Linear"],
            },
            "group_2": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": False,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["Embedding"],
            },
        },
        "ignore": ["LlamaRotaryEmbedding"],
    }
    return QuantizationConfig.parse_obj(config_dict)


# invoke the test at module level so the file can be run as a script
test_apply_quantization_config_tinyllama()
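
For reference, a short sketch (not part of the commit) that reuses the helpers above to inspect the quantization parameters the assertions check for; the test only guarantees these attributes exist, so the printed values are illustrative:

model = get_tinyllama_model()
apply_quantization_config(model, get_sample_tinyllama_quant_config())

for name, module in model.named_modules():
    if module.__class__.__name__ == "Linear":
        # attributes the test asserts exist after applying the config
        print(name, module.quantization_status)
        print("weight scale:", module.weight_scale)
        print("weight zero point:", module.weight_zero_point)
        print("input scale:", module.input_scale)
        print("input zero point:", module.input_zero_point)
        break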