Skip to content

Commit 7ca861b

Browse files
author
Benjamin
committed
move dynamic control to Quant Args
1 parent 5082aad commit 7ca861b

File tree

7 files changed

+19
-52
lines changed

7 files changed

+19
-52
lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _maybe_calibrate_or_quantize(
121121
return value
122122

123123
observer = getattr(module, f"{base_name}_observer")
124-
if observer.DYNAMIC:
124+
if args.dynamic:
125125
# dynamic quantization - get scale and zero point directly from observer
126126
scale, zero_point = observer(value)
127127
else:

src/compressed_tensors/quantization/lifecycle/frozen.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,17 @@ def freeze_module_quantization(module: Module):
3030
3131
:param module: module to freeze quantization for
3232
"""
33-
if not getattr(module, "quantization_scheme", None):
33+
scheme = getattr(module, "quantization_scheme", None)
34+
if not scheme:
3435
# no quantization scheme nothing to do
3536
return
3637

37-
# delete observers from module
38-
observer_names = []
39-
for submodule_name, submodule in module.named_modules():
40-
if "." not in submodule_name and submodule_name.endswith("_observer"):
41-
if getattr(submodule, "DYNAMIC", False):
42-
continue # do not delete dynamic observers
43-
44-
# delete any non-dynamic observers that belong directly to this module
45-
observer_names.append(submodule_name)
46-
for observer_name in observer_names:
47-
delattr(module, observer_name)
38+
# delete observers from module if not dynamic
39+
if scheme.input_activations and not scheme.input_activations.dynamic:
40+
delattr(module, "input_observer")
41+
if scheme.weights and not scheme.weights.dynamic:
42+
delattr(module, "weight_observer")
43+
if scheme.output_activations and not scheme.output_activations.dynamic:
44+
delattr(module, "output_observer")
4845

4946
module.quantization_status = QuantizationStatus.FROZEN

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _initialize_scale_zero_point_observer(
8484
observer = quantization_args.get_observer()
8585
module.register_module(f"{base_name}_observer", observer)
8686

87-
if observer.DYNAMIC:
87+
if quantization_args.dynamic:
8888
return # no need to register a scale and zero point for a dynamic observer
8989

9090
device = next(module.parameters()).device

src/compressed_tensors/quantization/observers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,3 @@
1919
from .base import *
2020
from .memoryless import *
2121
from .min_max import *
22-
from .dynamic import *

src/compressed_tensors/quantization/observers/dynamic.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/compressed_tensors/quantization/observers/memoryless.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
__all__ = ["MemorylessObserver"]
2424

2525

26-
@Observer.register("memoryless")
26+
@Observer.register("memoryless", alias=["dynamic"])
2727
class MemorylessObserver(Observer):
2828
"""
29-
Implements a dynamic quantization observer that sets the scale and
29+
Implements a quantization observer that sets the scale and
3030
zero point based on the latest observed value without tracking state
3131
"""
3232

src/compressed_tensors/quantization/quant_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class QuantizationArgs(BaseModel):
6161
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
6262
group_size: Optional[int] = None
6363
block_structure: Optional[str] = None
64+
dynamic: bool = False
6465
observer: str = Field(
6566
default="minmax",
6667
description=(
@@ -82,4 +83,9 @@ def get_observer(self):
8283
"""
8384
from compressed_tensors.quantization.observers.base import Observer
8485

86+
if self.observer == "minmax" and self.dynamic:
87+
# override defualt observer for dynamic, you never want minmax which
88+
# keeps state across samples for dynamic
89+
self.observer = "memoryless"
90+
8591
return Observer.load_from_registry(self.observer, quantization_args=self)

0 commit comments

Comments
 (0)