Commit d707c5b

Dynamic Quantization (#15)

* [WIP] Dynamic Quantization
* update imports post rename
* update dynamic bool
* move dynamic control to Quant Args
* Apply suggestions from code review
* docstring and test

1 parent dd2bd7f commit d707c5b

6 files changed, +164 -27 lines changed
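At a high level, the commit lets a quantization config mark an args group as dynamic. The snippet below is abbreviated from the sample config in the new test file at the end of this diff: weights keep static (calibrated) quantization, while input activations get fresh per-sample ranges at inference time.

```python
# Abbreviated from get_sample_dynamic_tinyllama_quant_config() in the new test file.
# Weights stay static; input activations are quantized dynamically per sample.
config_group = {
    "weights": {
        "num_bits": 8, "type": "int", "symmetric": True,
        "strategy": "tensor", "dynamic": False,
    },
    "input_activations": {
        "num_bits": 8, "type": "int", "symmetric": True,
        "strategy": "tensor", "dynamic": True,
    },
    "targets": ["Linear"],
}
```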

src/compressed_tensors/quantization/lifecycle/forward.py
Lines changed: 18 additions & 12 deletions

@@ -111,7 +111,7 @@ def wrapped_forward(self, *args, **kwargs):
 
 
 def _maybe_calibrate_or_quantize(
-    module: Module, value: Module, base_name: str, args: "QuantizationArgs"
+    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
     # only run quantized for the included stages
     if module.quantization_status not in {
@@ -120,17 +120,23 @@ def _maybe_calibrate_or_quantize(
     }:
         return value
 
-    device = next(module.parameters()).device
-    scale = getattr(module, f"{base_name}_scale")
-    zero_point = getattr(module, f"{base_name}_zero_point")
-
-    if module.quantization_status == QuantizationStatus.CALIBRATION:
-        # get observer and get new quant params from observation
+    if args.dynamic:
+        # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
-        updated_scale, updated_zero_point = observer(value)
-
-        # update scale and zero point
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+        scale, zero_point = observer(value)
+    else:
+        # static quantization - get previous scale and zero point from layer
+        scale = getattr(module, f"{base_name}_scale")
+        zero_point = getattr(module, f"{base_name}_zero_point")
+
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            # calibration mode - get new quant params from observer
+            observer = getattr(module, f"{base_name}_observer")
+            updated_scale, updated_zero_point = observer(value)
+
+            # update scale and zero point
+            device = next(module.parameters()).device
+            scale.data = updated_scale.to(device)
+            zero_point.data = updated_zero_point.to(device)
 
     return fake_quantize(value, scale, zero_point, args)

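When `args.dynamic` is set, the rewritten `_maybe_calibrate_or_quantize` asks the observer for a scale and zero point on every call instead of reading the stored `{base_name}_scale` / `{base_name}_zero_point` parameters. A minimal, self-contained sketch of that distinction in plain PyTorch follows; the helpers are illustrative stand-ins, not the library's `fake_quantize` or observer implementations.

```python
import torch


def toy_fake_quantize(x, scale, zero_point, num_bits=8):
    # quantize then dequantize so the error of the chosen range becomes visible
    q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
    return (q - zero_point) * scale


def toy_dynamic_fake_quantize(x, num_bits=8):
    # dynamic path: derive a symmetric range from this sample alone, on every call
    max_abs = x.abs().max().clamp(min=1e-8)
    scale = max_abs / (2 ** (num_bits - 1) - 1)
    return toy_fake_quantize(x, scale, torch.zeros(()), num_bits)


x = torch.randn(4, 8)

# static path: a scale/zero point produced during calibration and stored on the module
calibrated_scale, calibrated_zp = torch.tensor(0.02), torch.tensor(0.0)
static_out = toy_fake_quantize(x, calibrated_scale, calibrated_zp)

# dynamic path: no stored parameters, the range tracks each incoming sample
dynamic_out = toy_dynamic_fake_quantize(x)
```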
src/compressed_tensors/quantization/lifecycle/frozen.py
Lines changed: 9 additions & 9 deletions

@@ -30,17 +30,17 @@ def freeze_module_quantization(module: Module):
 
     :param module: module to freeze quantization for
     """
-    if not getattr(module, "quantization_scheme", None):
+    scheme = getattr(module, "quantization_scheme", None)
+    if not scheme:
         # no quantization scheme nothing to do
         return
 
-    # delete observers from module
-    observer_names = []
-    for submodule_name, _ in module.named_modules():
-        if "." not in submodule_name and submodule_name.endswith("_observer"):
-            # delete any observers that belong directly to this module
-            observer_names.append(submodule_name)
-    for observer_name in observer_names:
-        delattr(module, observer_name)
+    # delete observers from module if not dynamic
+    if scheme.input_activations and not scheme.input_activations.dynamic:
+        delattr(module, "input_observer")
+    if scheme.weights and not scheme.weights.dynamic:
+        delattr(module, "weight_observer")
+    if scheme.output_activations and not scheme.output_activations.dynamic:
+        delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN

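After this change, freezing removes an observer only when its args group is static; a dynamic group keeps its observer because it is still needed at inference time. A toy sketch of the underlying mechanics (a hypothetical module, not `freeze_module_quantization` itself), relying on the fact that `delattr` on an `nn.Module` also removes registered submodules:

```python
from torch.nn import Identity, Linear, Module


class ToyQuantizedLinear(Module):
    def __init__(self, dynamic_inputs: bool):
        super().__init__()
        self.layer = Linear(4, 4)
        self.dynamic_inputs = dynamic_inputs
        # observers are registered as submodules, as initialize.py does below
        self.input_observer = Identity()
        self.weight_observer = Identity()

    def freeze(self):
        # weight quantization is static here, so its observer is no longer needed
        delattr(self, "weight_observer")
        # keep the input observer only if inputs are quantized dynamically
        if not self.dynamic_inputs:
            delattr(self, "input_observer")


m = ToyQuantizedLinear(dynamic_inputs=True)
m.freeze()
print(hasattr(m, "weight_observer"), hasattr(m, "input_observer"))  # False True
```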
src/compressed_tensors/quantization/lifecycle/initialize.py
Lines changed: 7 additions & 4 deletions

@@ -80,6 +80,13 @@ def initialize_module_for_quantization(
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
+
+    if quantization_args.dynamic:
+        return  # no need to register a scale and zero point for a dynamic observer
+
     device = next(module.parameters()).device
 
     # initializes empty scale and zero point parameters for the module
@@ -90,7 +97,3 @@ def _initialize_scale_zero_point_observer(
         torch.empty(0, device=device, dtype=int), requires_grad=False
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)

src/compressed_tensors/quantization/observers/memoryless.py
Lines changed: 2 additions & 2 deletions

@@ -23,10 +23,10 @@
 __all__ = ["MemorylessObserver"]
 
 
-@Observer.register("memoryless")
+@Observer.register("memoryless", alias=["dynamic"])
 class MemorylessObserver(Observer):
     """
-    Implements a dynamic quantization observer that sets the scale and
+    Implements a quantization observer that sets the scale and
     zero point based on the latest observed value without tracking state
     """

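For intuition: a memoryless observer derives its scale and zero point from whatever tensor it is currently shown, with no running statistics across samples, which is exactly what the dynamic path in forward.py relies on. A rough standalone sketch of that behavior (not the registry-backed `MemorylessObserver` implementation; 8-bit ranges assumed):

```python
import torch


def memoryless_observe(x: torch.Tensor, num_bits: int = 8, symmetric: bool = True):
    """Return (scale, zero_point) computed only from the tensor passed in."""
    if symmetric:
        max_abs = x.abs().max().clamp(min=1e-8)
        scale = max_abs / (2 ** (num_bits - 1) - 1)
        zero_point = torch.zeros((), dtype=torch.int64)
    else:
        # asymmetric: map [min, max] (forced to straddle zero) onto [0, 2**num_bits - 1]
        x_min, x_max = x.min().clamp(max=0), x.max().clamp(min=0)
        scale = (x_max - x_min).clamp(min=1e-8) / (2 ** num_bits - 1)
        zero_point = torch.round(-x_min / scale).to(torch.int64)
    return scale, zero_point


# every call starts from scratch; nothing is remembered between samples
scale, zero_point = memoryless_observe(torch.randn(16, 16))
```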
src/compressed_tensors/quantization/quant_args.py
Lines changed: 11 additions & 0 deletions

@@ -53,6 +53,11 @@ class QuantizationArgs(BaseModel):
     :param group_size: group length to use for the group strategy
     :param block_structure: 2d block structure to use for the block strategy, must be
         of the format "2x4", "8x16", etc.
+    :param dynamic: set True to perform dynamic quantization - values will not be
+        calibrated during calibration phase, instead during inference new quantization
+        ranges will be observed with every sample. Defaults to False for static
+        quantization. Note that enabling dynamic quantization will change the default
+        observer to a memoryless one
     """
 
     num_bits: int = 8
@@ -61,6 +66,7 @@
     strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
     group_size: Optional[int] = None
     block_structure: Optional[str] = None
+    dynamic: bool = False
     observer: str = Field(
         default="minmax",
         description=(
@@ -82,4 +88,9 @@ def get_observer(self):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
+        if self.observer == "minmax" and self.dynamic:
+            # override default observer for dynamic, you never want minmax which
+            # keeps state across samples for dynamic
+            self.observer = "memoryless"
+
         return Observer.load_from_registry(self.observer, quantization_args=self)
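With the new field and the override in `get_observer()`, enabling dynamic quantization is a one-flag change on `QuantizationArgs`: the `minmax` default is swapped for the memoryless observer, which is also reachable explicitly via its new `dynamic` alias. A small usage sketch, assuming the remaining fields keep the defaults shown above (not executed here):

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs

static_args = QuantizationArgs()               # defaults: 8-bit, observer="minmax"
dynamic_args = QuantizationArgs(dynamic=True)  # same defaults, but dynamic

observer = dynamic_args.get_observer()  # swaps the default observer to "memoryless"
print(static_args.observer, dynamic_args.observer)  # minmax memoryless
```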
Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from compressed_tensors.quantization.lifecycle import (
+    apply_quantization_config,
+    freeze_module_quantization,
+)
+from compressed_tensors.quantization.quant_config import QuantizationConfig
+from transformers import AutoModelForCausalLM
+
+
+def test_apply_tinyllama_dynamic_activations():
+    quant_config = get_sample_dynamic_tinyllama_quant_config()
+    model = get_tinyllama_model()
+
+    # check that model is not already quantized
+    for module in model.modules():
+        _test_layer_dynamic_quantization_status(module, inputs=False, weights=False)
+
+    # apply quant config to model
+    apply_quantization_config(model, quant_config)
+
+    # test linears are dynamically quantized for calibration
+    _test_linears_dynamic_quantization_status(model, quant_config, frozen=False)
+    # verify forward works w/ dynamic during calibration
+    model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))
+
+    # freeze and test that only weight observers are deleted
+    model.apply(freeze_module_quantization)
+    _test_linears_dynamic_quantization_status(model, quant_config, frozen=True)
+    # verify forward works w/ dynamic after freeze
+    model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))
+
+
+def _test_linears_dynamic_quantization_status(model, quant_config, frozen: bool):
+    # check for correct application of quant config
+    num_linears = 0
+    for name, module in model.named_modules():
+        if name in quant_config.ignore:
+            continue
+        module_type = module.__class__.__name__
+        if module_type == "Linear":
+            num_linears += 1
+            _test_layer_dynamic_quantization_status(
+                module, inputs=True, weights=True, frozen=frozen
+            )
+
+    # sanity check correct number of layers targeted
+    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored
+
+
+def _test_layer_dynamic_quantization_status(
+    module, inputs: bool, weights: bool, frozen: bool = False
+):
+    # check if quantization is applied at all (true if inputs or weights targeted)
+    quantized = inputs or weights
+    assert hasattr(module, "quantization_scheme") == quantized
+    assert hasattr(module, "quantization_status") == quantized
+
+    # check inputs always have an observer if quantized but never scale/zp
+    assert not hasattr(module, "input_scale")
+    assert not hasattr(module, "input_zero_point")
+    assert hasattr(module, "input_observer") == inputs
+
+    # check weights always have scale/zp and observer only if not frozen
+    assert hasattr(module, "weight_scale") == weights
+    assert hasattr(module, "weight_zero_point") == weights
+    assert hasattr(module, "weight_observer") == (weights and not frozen)
+
+
+def get_tinyllama_model():
+    return AutoModelForCausalLM.from_pretrained(
+        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+    )
+
+
+def get_sample_dynamic_tinyllama_quant_config():
+    config_dict = {
+        "quant_method": "sparseml",
+        "format": "fakequant",
+        "quantization_status": "calibration",
+        "global_compression_ratio": None,
+        "config_groups": {
+            "group_1": {
+                "weights": {
+                    "num_bits": 8,
+                    "type": "int",
+                    "symmetric": True,
+                    "strategy": "tensor",
+                    "dynamic": False,
+                },
+                "input_activations": {
+                    "num_bits": 8,
+                    "type": "int",
+                    "symmetric": True,
+                    "strategy": "tensor",
+                    "dynamic": True,
+                },
+                "targets": ["Linear"],
+            },
+        },
+        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
+    }
+    return QuantizationConfig.parse_obj(config_dict)
