
Commit 2395833

Author: Benjamin
Commit message: [WIP] Dynamic Quantization
1 parent ee6a913 commit 2395833

File tree: 6 files changed, +68 −19 lines

src/compressed_tensors/quantization/lifecycle/forward.py
src/compressed_tensors/quantization/lifecycle/frozen.py
src/compressed_tensors/quantization/lifecycle/initialize.py
src/compressed_tensors/quantization/observers/__init__.py
src/compressed_tensors/quantization/observers/base.py
src/compressed_tensors/quantization/observers/dynamic.py

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 17 additions & 13 deletions
```diff
@@ -112,18 +112,22 @@ def _maybe_calibrate_or_quantize(
     }:
         return value
 
-    device = next(module.parameters()).device
-    scale = getattr(module, f"{base_name}_scale")
-    # zero_point = getattr(module, f"{base_name}_zero_point").data
-    zero_point = getattr(module, f"{base_name}_zero_point")
-
-    if module.quantization_status == QuantizationStatus.CALIBRATION:
-        # get observer and get new quant params from observation
-        observer = getattr(module, f"{base_name}_observer")
-        updated_scale, updated_zero_point = observer(value)
-
-        # update scale and zero point
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+    observer = getattr(module, f"{base_name}_observer")
+    if observer.DYNAMIC:
+        # dynamic quantization - get scale and zero point directly from observer
+        scale, zero_point = observer(value)
+    else:
+        # static quantization - get previous scale and zero point from layer
+        scale = getattr(module, f"{base_name}_scale")
+        zero_point = getattr(module, f"{base_name}_zero_point")
+
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            # calibration mode - get new quant params from observer
+            updated_scale, updated_zero_point = observer(value)
+
+            # update scale and zero point
+            device = next(module.parameters()).device
+            scale.data = updated_scale.to(device)
+            zero_point.data = updated_zero_point.to(device)
 
     return fake_quantize(value, scale, zero_point, args)
```
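For context, the hunk ends by delegating to `fake_quantize`, which performs a quantize/dequantize round trip with whichever scale and zero point were selected above. A minimal sketch of that round trip, assuming an unsigned 8-bit affine scheme (`fake_quantize_sketch` is a hypothetical stand-in; the real function derives its range and strategy from `args`):

```python
import torch

def fake_quantize_sketch(value: torch.Tensor, scale, zero_point, bits: int = 8):
    # affine quantize into the integer range, then dequantize back to float;
    # the round trip injects quantization error while keeping float dtype
    q_min, q_max = 0, 2**bits - 1  # unsigned range assumed for illustration
    q = torch.clamp(torch.round(value / scale) + zero_point, q_min, q_max)
    return (q - zero_point) * scale
```

With a dynamic observer, `scale` and `zero_point` here are freshly computed on each call; on the static path they are the persisted module parameters.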

src/compressed_tensors/quantization/lifecycle/frozen.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -36,9 +36,12 @@ def freeze_module_quantization(module: Module):
 
     # delete observers from module
     observer_names = []
-    for submodule_name, _ in module.named_modules():
+    for submodule_name, submodule in module.named_modules():
         if "." not in submodule_name and submodule_name.endswith("_observer"):
-            # delete any observers that belong directly to this module
+            if getattr(submodule, "DYNAMIC", False):
+                continue  # do not delete dynamic observers
+
+            # delete any non-dynamic observers that belong directly to this module
            observer_names.append(submodule_name)
     for observer_name in observer_names:
         delattr(module, observer_name)
```
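The `"." not in submodule_name` guard works because `torch.nn.Module.named_modules()` yields dotted paths for nested children, so only observers attached directly to the module survive the filter. A standalone illustration of that behavior (toy module names, not the library's):

```python
from torch.nn import Linear, Module, Sequential

class Layer(Module):
    def __init__(self):
        super().__init__()
        self.block = Sequential(Linear(4, 4))  # nested child -> dotted name
        self.input_observer = Module()         # direct child -> no dot

layer = Layer()
print([name for name, _ in layer.named_modules()])
# ['', 'block', 'block.0', 'input_observer']
print([n for n, _ in layer.named_modules()
       if "." not in n and n.endswith("_observer")])
# ['input_observer']
```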

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -78,6 +78,13 @@ def initialize_module_for_quantization(
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
+
+    if observer.DYNAMIC:
+        return  # no need to register a scale and zero point for a dynamic observer
+
     device = next(module.parameters()).device
 
     # initializes empty scale and zero point parameters for the module
@@ -88,7 +95,3 @@
         torch.empty(0, device=device, dtype=int), requires_grad=False
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
```
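The net effect on module state: static observers get empty `scale`/`zero_point` parameters to fill during calibration, while dynamic observers hit the early return and leave no persistent qparams behind. A plain-torch sketch of the two outcomes (the helper and stand-in observer below are illustrative, not the library's code):

```python
import torch
from torch.nn import Linear, Module, Parameter

def init_qparams(module: Module, base_name: str, dynamic: bool) -> None:
    # stand-in for _initialize_scale_zero_point_observer
    module.register_module(f"{base_name}_observer", Module())  # stand-in observer
    if dynamic:
        return  # dynamic: qparams are recomputed per tensor, never stored

    device = next(module.parameters()).device
    module.register_parameter(
        f"{base_name}_scale",
        Parameter(torch.empty(0, device=device), requires_grad=False),
    )
    module.register_parameter(
        f"{base_name}_zero_point",
        Parameter(torch.empty(0, device=device, dtype=torch.int64), requires_grad=False),
    )

static_layer, dynamic_layer = Linear(4, 4), Linear(4, 4)
init_qparams(static_layer, "input", dynamic=False)
init_qparams(dynamic_layer, "input", dynamic=True)
print(hasattr(static_layer, "input_scale"))   # True
print(hasattr(dynamic_layer, "input_scale"))  # False
```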

src/compressed_tensors/quantization/observers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,3 +19,4 @@
 from .base import *
 from .memoryless import *
 from .min_max import *
+from .dynamic import *
```

src/compressed_tensors/quantization/observers/base.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -30,6 +30,9 @@ class Observer(Module, RegistryMixin):
     pair
     """
 
+    # child classes should set to True if they are meant to be used as dynamic
+    DYNAMIC = False
+
     def __init__(self, quantization_args: QuantizationArgs):
         self.quantization_args: QuantizationArgs = quantization_args
         super().__init__()
```
src/compressed_tensors/quantization/observers/dynamic.py

Lines changed: 35 additions & 0 deletions

```diff
@@ -0,0 +1,35 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.observers.memoryless import MemorylessObserver
+
+
+__all__ = ["DynamicObserver"]
+
+
+@Observer.register("dynamic")
+class DynamicObserver(MemorylessObserver):
+    """
+    Values targeted for a dynamic observer do not require calibration;
+    this observer will persist in the model through the lifecycle, calculating
+    the quantization parameters on the fly for each observed Tensor.
+
+    This base dynamic observer uses `calculate_qparams` from MemorylessObserver,
+    where each scale and zero point is based solely on the currently observed
+    Tensor.
+    """
+
+    DYNAMIC = True
```
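In practice, a memoryless/dynamic observer reduces to computing per-tensor min/max quantization parameters on every call. A self-contained sketch of that math, assuming an unsigned 8-bit affine scheme (`dynamic_qparams` is illustrative; the real `calculate_qparams` in MemorylessObserver also handles symmetric schemes and other bit widths):

```python
import torch

def dynamic_qparams(x: torch.Tensor, bits: int = 8):
    # per-tensor affine qparams from the observed min/max, recomputed each call
    q_min, q_max = 0, 2**bits - 1
    x_min = torch.clamp(x.min(), max=0.0)  # keep zero inside the range
    x_max = torch.clamp(x.max(), min=0.0)
    scale = (x_max - x_min) / (q_max - q_min)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)  # avoid div by 0
    zero_point = torch.round(q_min - x_min / scale).to(torch.int64)
    return scale, zero_point

scale, zero_point = dynamic_qparams(torch.randn(16))
print(scale.item(), zero_point.item())
```

Because `DynamicObserver` sets `DYNAMIC = True`, the freeze step above leaves it attached to the module, and `_maybe_calibrate_or_quantize` calls it on every forward pass.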
