Apply quantization config implementation (#4)
* Apply quantization config implementation

* add TODO

* integrate full lifecycle support, QuantizationStatus updates, add tinyllama test

* fix comment
bfineran authored Apr 15, 2024
1 parent f64ec82 commit 129770e
Showing 12 changed files with 296 additions and 20 deletions.
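For orientation, here is a minimal usage sketch of the entry point this commit adds. It is not part of the diff; the model name and config values simply mirror the TinyLlama test added below, and the specific field values are illustrative assumptions.

# Illustrative sketch only; mirrors the sample config in the test added below.
from sparsetensors.quantization import QuantizationConfig
from sparsetensors.quantization.lifecycle import apply_quantization_config
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
)
config = QuantizationConfig.parse_obj(
    {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["Linear"],
            },
        },
        "ignore": ["LlamaRotaryEmbedding"],
    }
)
# applies the quantization lifecycle in-place, advancing every matched module
# up to the configured status ("frozen" here)
apply_quantization_config(model, config)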
2 changes: 2 additions & 0 deletions src/sparsetensors/quantization/__init__.py
@@ -13,6 +13,8 @@
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .quant_args import *
from .quant_config import *
from .quant_scheme import *
3 changes: 2 additions & 1 deletion src/sparsetensors/quantization/lifecycle/__init__.py
@@ -13,9 +13,10 @@
# limitations under the License.

# flake8: noqa
+# isort: skip_file

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
-from .status import *
+from .apply import *
120 changes: 120 additions & 0 deletions src/sparsetensors/quantization/lifecycle/apply.py
@@ -0,0 +1,120 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections import OrderedDict
from typing import Iterable, Optional, Tuple

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from sparsetensors.quantization.quant_config import (
    QuantizationConfig,
    QuantizationStatus,
)
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module


__all__ = [
"apply_quantization_config",
"apply_quantization_status",
]


def apply_quantization_config(model: Module, config: QuantizationConfig):
    """
    Initializes the model for quantization in-place based on the given config

    :param model: model to apply quantization config to
    :param config: quantization config
    """
    # build mapping of targets to schemes for easier matching
    # use ordered dict to preserve target ordering in config
    target_to_scheme = OrderedDict()
    for scheme in config.config_groups.values():
        for target in scheme.targets:
            target_to_scheme[target] = scheme

    # build list of layers to target to avoid mutating submodule dict during iteration
    layer_quant_scheme_pairs = []
    for name, submodule in _iter_named_leaf_modules(model):
        if _find_first_name_or_class_match(name, submodule, config.ignore):
            continue  # layer matches ignore list, continue
        target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
        if target is not None:
            # target matched - add layer and scheme to target list
            layer_quant_scheme_pairs.append((submodule, target_to_scheme[target]))

    # apply current quantization status for each matched pair
    for layer, scheme in layer_quant_scheme_pairs:
        apply_quantization_status(
            module=layer,
            scheme=scheme,
            status=config.quantization_status,
        )


def apply_quantization_status(
    module: Module, scheme: QuantizationScheme, status: QuantizationStatus
):
    """
    Applies in place the quantization lifecycle up to the given status

    :param module: module to apply quantization to
    :param scheme: quantization scheme to apply
    :param status: status to update the module to
    """
    if status >= QuantizationStatus.INITIALIZED:
        initialize_module_for_quantization(module, scheme)
    if status >= QuantizationStatus.CALIBRATION:
        set_module_for_calibration(module)
    if status >= QuantizationStatus.FROZEN:
        freeze_module_quantization(module)


def _iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
    # yields modules that do not have any submodules
    # TODO: potentially expand to add list of allowed submodules such as observers
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:
            yield name, submodule


def _find_first_name_or_class_match(
    name: str,
    module: Module,
    targets: Iterable[str],
) -> Optional[str]:
    # first element of targets that matches the given name
    # if no name matches returns first target that matches the class name
    # returns None otherwise
    return _find_first_match(name, targets) or _find_first_match(
        module.__class__.__name__, targets
    )


def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
    # returns first element of target that matches value either
    # exactly or as a regex after 're:'
    for target in targets:
        if target.startswith("re:"):
            pattern = target[3:]
            if re.match(pattern, value):
                return target
        elif target == value:
            return target
    return None
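As a side note, the following sketch (not part of the commit) illustrates the matching rules the private helpers above implement: a target is resolved against the module name first, either as an exact string or as a regex after the "re:" prefix, and only then against the module's class name.

# Illustrative sketch of the matching helpers defined above; not part of the diff.
from sparsetensors.quantization.lifecycle.apply import _find_first_name_or_class_match
from torch.nn import Linear

layer = Linear(16, 16)

# falls back to the class name when no target matches the module name
assert _find_first_name_or_class_match("model.layers.0.q_proj", layer, ["Linear"]) == "Linear"

# name-based matches (exact or "re:" regex) are checked before class names
targets = ["re:.*q_proj$", "Linear"]
assert _find_first_name_or_class_match("model.layers.0.q_proj", layer, targets) == "re:.*q_proj$"

# no match at all returns None, which apply_quantization_config treats as "skip this layer"
assert _find_first_name_or_class_match("model.embed_tokens", layer, ["re:.*k_proj$"]) is None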
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/calibration.py
@@ -15,7 +15,7 @@

import logging

-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/forward.py
@@ -15,8 +15,8 @@
from functools import wraps

import torch
-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module

7 changes: 5 additions & 2 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -13,7 +13,7 @@
# limitations under the License.


-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_config import QuantizationStatus
from torch.nn import Module


@@ -28,9 +28,12 @@ def freeze_module_quantization(module: Module):
        return

    # delete observers from module
+    observer_names = []
    for submodule_name, _ in module.named_modules():
        if "." not in submodule_name and submodule_name.endswith("_observer"):
            # delete any observers that belong directly to this module
-            delattr(module, submodule_name)
+            observer_names.append(submodule_name)
+    for observer_name in observer_names:
+        delattr(module, observer_name)

    module.quantization_status = QuantizationStatus.FROZEN
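A small standalone sketch (toy modules, not from the commit) of the collect-then-delete pattern introduced above: deferring delattr until after iteration avoids mutating the module's submodule dict while named_modules() is still walking it, the same concern noted in apply_quantization_config.

# Toy illustration of the deferred-deletion pattern; Identity stands in for an observer.
from torch.nn import Identity, Linear

layer = Linear(4, 4)
layer.weight_observer = Identity()
layer.input_observer = Identity()

observer_names = [
    name
    for name, _ in layer.named_modules()
    if "." not in name and name.endswith("_observer")
]
for name in observer_names:  # delete only after iteration has finished
    delattr(layer, name)

assert not hasattr(layer, "weight_observer")
assert not hasattr(layer, "input_observer")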
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/lifecycle/initialize.py
@@ -17,8 +17,8 @@

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
-from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_config import QuantizationStatus
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter

26 changes: 25 additions & 1 deletion src/sparsetensors/quantization/quant_config.py
@@ -19,7 +19,11 @@
from sparsetensors.quantization.quant_scheme import QuantizationScheme


-__all__ = ["QuantizationStatus", "QuantizationConfig"]
+__all__ = [
+    "QuantizationStatus",
+    "QuantizationConfig",
+    "LIFECYCLE_ORDER",
+]


class QuantizationStatus(Enum):
@@ -41,6 +45,26 @@ class QuantizationStatus(Enum):
    FROZEN = "frozen"
    COMPRESSED = "compressed"

    @classmethod
    def lifecycle_order(cls) -> List["QuantizationStatus"]:
        """
        :return: list of correct quantization lifecycle order
        """
        return LIFECYCLE_ORDER

    def __ge__(self, other):
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)


LIFECYCLE_ORDER = [
    QuantizationStatus.INITIALIZED,
    QuantizationStatus.CALIBRATION,
    QuantizationStatus.FROZEN,
    QuantizationStatus.COMPRESSED,
]


class QuantizationConfig(BaseModel):
    """
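A brief sketch (illustrative, not part of the diff) of what LIFECYCLE_ORDER and the new __ge__ provide: status comparisons follow the quantization lifecycle rather than enum declaration order, which is what lets apply_quantization_status run every stage up to and including the requested one.

# Illustrative only; uses members visible in the hunk above.
from sparsetensors.quantization.quant_config import LIFECYCLE_ORDER, QuantizationStatus

assert QuantizationStatus.FROZEN >= QuantizationStatus.CALIBRATION
assert QuantizationStatus.CALIBRATION >= QuantizationStatus.INITIALIZED
assert not (QuantizationStatus.INITIALIZED >= QuantizationStatus.FROZEN)
assert LIFECYCLE_ORDER[-1] is QuantizationStatus.COMPRESSED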
13 changes: 0 additions & 13 deletions ...etensors/quantization/lifecycle/status.py → tests/__init__.py
@@ -11,16 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

-from enum import Enum
-
-
-__all__ = [
-    "QuantizationStatus",
-]
-
-
-class QuantizationStatus(Enum):
-    INITIALIZED = "INITIALIZED"
-    CALIBRATION = "CALIBRATION"
-    FROZEN = "FROZEN"
13 changes: 13 additions & 0 deletions tests/quantization/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/quantization/lifecycle/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
113 changes: 113 additions & 0 deletions tests/quantization/lifecycle/test_apply.py
@@ -0,0 +1,113 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle import apply_quantization_config
from sparsetensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM


def test_apply_quantization_config_tinyllama():
    quant_config = get_sample_tinyllama_quant_config()
    model = get_tinyllama_model()

    # check that model is not already quantized
    for module in model.modules():
        _test_layer_quantization_status(module, inputs=False, weights=False)

    # apply quant config to model
    apply_quantization_config(model, quant_config)

    # check for correct application of quant config
    num_linears = 0
    num_embeddings = 0
    num_rotary_embeddings = 0
    for module in model.modules():
        module_type = module.__class__.__name__
        if module_type == "Linear":
            num_linears += 1
            _test_layer_quantization_status(module, inputs=True, weights=True)
        elif module_type == "Embedding":
            num_embeddings += 1
            _test_layer_quantization_status(module, inputs=False, weights=True)
        elif module_type == "LlamaRotaryEmbedding":
            num_rotary_embeddings += 1
            _test_layer_quantization_status(module, inputs=False, weights=False)

    # sanity check correct number of layers targeted
    assert num_linears == 155
    assert num_embeddings == 1
    assert num_rotary_embeddings == 22


def _test_layer_quantization_status(module, inputs: bool, weights: bool):
    # check if quantization is applied at all (true if inputs or weights targeted)
    quantized = inputs or weights
    assert hasattr(module, "quantization_scheme") == quantized
    assert hasattr(module, "quantization_status") == quantized

    # check inputs matches expected
    assert hasattr(module, "input_scale") == inputs
    assert hasattr(module, "input_zero_point") == inputs

    # check weights matches expected
    assert hasattr(module, "weight_scale") == weights
    assert hasattr(module, "weight_zero_point") == weights


def get_tinyllama_model():
    return AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    )


def get_sample_tinyllama_quant_config():
    config_dict = {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "frozen",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "input_activations": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                },
                "targets": ["Linear"],
            },
            "group_2": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": False,
                    "strategy": "tensor",
                },
                "input_activations": None,
                "targets": ["Embedding"],
            },
        },
        "ignore": ["LlamaRotaryEmbedding"],
    }
    return QuantizationConfig.parse_obj(config_dict)


test_apply_quantization_config_tinyllama()
