Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dyanmic Quantization #15

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,22 @@ def _maybe_calibrate_or_quantize(
}:
return value

device = next(module.parameters()).device
scale = getattr(module, f"{base_name}_scale")
zero_point = getattr(module, f"{base_name}_zero_point")

if module.quantization_status == QuantizationStatus.CALIBRATION:
# get observer and get new quant params from observation
observer = getattr(module, f"{base_name}_observer")
updated_scale, updated_zero_point = observer(value)

# update scale and zero point
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)
observer = getattr(module, f"{base_name}_observer")
if args.dynamic:
# dynamic quantization - get scale and zero point directly from observer
scale, zero_point = observer(value)
else:
# static quantization - get previous scale and zero point from layer
scale = getattr(module, f"{base_name}_scale")
zero_point = getattr(module, f"{base_name}_zero_point")

if module.quantization_status == QuantizationStatus.CALIBRATION:
# calibration mode - get new quant params from observer
updated_scale, updated_zero_point = observer(value)

# update scale and zero point
device = next(module.parameters()).device
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)

return fake_quantize(value, scale, zero_point, args)
18 changes: 9 additions & 9 deletions src/compressed_tensors/quantization/lifecycle/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ def freeze_module_quantization(module: Module):

:param module: module to freeze quantization for
"""
if not getattr(module, "quantization_scheme", None):
scheme = getattr(module, "quantization_scheme", None)
if not scheme:
# no quantization scheme nothing to do
return

# delete observers from module
observer_names = []
for submodule_name, _ in module.named_modules():
if "." not in submodule_name and submodule_name.endswith("_observer"):
# delete any observers that belong directly to this module
observer_names.append(submodule_name)
for observer_name in observer_names:
delattr(module, observer_name)
# delete observers from module if not dynamic
if scheme.input_activations and not scheme.input_activations.dynamic:
delattr(module, "input_observer")
if scheme.weights and not scheme.weights.dynamic:
delattr(module, "weight_observer")
if scheme.output_activations and not scheme.output_activations.dynamic:
delattr(module, "output_observer")

module.quantization_status = QuantizationStatus.FROZEN
11 changes: 7 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ def initialize_module_for_quantization(
def _initialize_scale_zero_point_observer(
module: Module, base_name: str, quantization_args: QuantizationArgs
):
# initialize observer module and attach as submodule
observer = quantization_args.get_observer()
module.register_module(f"{base_name}_observer", observer)

if quantization_args.dynamic:
return # no need to register a scale and zero point for a dynamic observer

device = next(module.parameters()).device

# initializes empty scale and zero point parameters for the module
Expand All @@ -90,7 +97,3 @@ def _initialize_scale_zero_point_observer(
torch.empty(0, device=device, dtype=int), requires_grad=False
)
module.register_parameter(f"{base_name}_zero_point", init_zero_point)

# initialize observer module and attach as submodule
observer = quantization_args.get_observer()
module.register_module(f"{base_name}_observer", observer)
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/observers/memoryless.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
__all__ = ["MemorylessObserver"]


@Observer.register("memoryless")
@Observer.register("memoryless", alias=["dynamic"])
class MemorylessObserver(Observer):
"""
Implements a dynamic quantization observer that sets the scale and
Implements a quantization observer that sets the scale and
zero point based on the latest observed value without tracking state
"""

Expand Down
6 changes: 6 additions & 0 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class QuantizationArgs(BaseModel):
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
group_size: Optional[int] = None
block_structure: Optional[str] = None
dynamic: bool = False
bfineran marked this conversation as resolved.
Show resolved Hide resolved
observer: str = Field(
default="minmax",
description=(
Expand All @@ -82,4 +83,9 @@ def get_observer(self):
"""
from compressed_tensors.quantization.observers.base import Observer

if self.observer == "minmax" and self.dynamic:
# override defualt observer for dynamic, you never want minmax which
# keeps state across samples for dynamic
self.observer = "memoryless"

return Observer.load_from_registry(self.observer, quantization_args=self)
Loading