diff --git a/setup.py b/setup.py
index 12ef67a3..89a17ad7 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ def _setup_install_requires() -> List:
     return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"]
 
 def _setup_extras() -> Dict:
-    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
+    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0",]}
 
 setup(
     name="sparsetensors",
diff --git a/src/sparsetensors/quantization/lifecycle/__init__.py b/src/sparsetensors/quantization/lifecycle/__init__.py
new file mode 100644
index 00000000..52b86440
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .calibration import *
+from .forward import *
+from .frozen import *
+from .initialize import *
+from .status import *
diff --git a/src/sparsetensors/quantization/lifecycle/calibration.py b/src/sparsetensors/quantization/lifecycle/calibration.py
new file mode 100644
index 00000000..986b062a
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "set_module_for_calibration",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def set_module_for_calibration(module: Module):
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme, nothing to do
+        return
+    status = getattr(module, "quantization_status", None)
+    if not status or status != QuantizationStatus.INITIALIZED:
+        _LOGGER.warning(
+            f"Attempting to set module with status {status} to calibration mode "
+            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
+            "be calibrating an uninitialized module which may fail or attempting "
+            "to re-calibrate a frozen module"
+        )
+
+    module.quantization_status = QuantizationStatus.CALIBRATION
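For context, a minimal sketch of how the hook above behaves (not part of this PR; the second call assumes a module already passed through initialize_module_for_quantization, added later in this diff):

# Sketch (not part of this PR): status handling of set_module_for_calibration.
from torch.nn import Linear

from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration

layer = Linear(8, 8)
set_module_for_calibration(layer)  # no quantization_scheme attribute -> silent no-op

# After initialize_module_for_quantization(layer, scheme) the module carries
# quantization_status == QuantizationStatus.INITIALIZED, and this call flips it:
# set_module_for_calibration(layer)  # -> QuantizationStatus.CALIBRATION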
diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
new file mode 100644
index 00000000..ab20e29b
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import wraps
+
+import torch
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from torch.nn import Module
+
+
+__all__ = ["wrap_module_forward_quantized"]
+
+
+def quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_max: torch.Tensor,
+) -> torch.Tensor:
+    return torch.clamp(
+        torch.round(
+            x / scale + zero_point,
+        ),
+        0,
+        q_max,
+    )
+
+
+def dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
+
+
+def fake_quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+) -> torch.Tensor:
+    max_q = torch.tensor(2**args.num_bits - 1)
+    Q = quantize(x, scale, zero_point, max_q)
+    return dequantize(Q, scale, zero_point)
+
+
+def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    forward_func_orig = module.forward.__func__
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+        input_ = args[0]
+
+        if scheme.input_activations is not None:
+            # calibrate and (fake) quantize input activations when applicable
+            input_ = _maybe_calibrate_or_quantize(
+                module, input_, "input", scheme.input_activations
+            )
+
+        if scheme.weights is not None:
+            # calibrate and (fake) quantize weights when applicable
+            self.weight.data = _maybe_calibrate_or_quantize(
+                module, self.weight, "weight", scheme.weights
+            )
+
+        # perform wrapped forward call
+        output = forward_func_orig.__get__(module, module.__class__)(
+            input_, *args[1:], **kwargs
+        )
+
+        if scheme.output_activations is not None:
+            # calibrate and (fake) quantize output activations when applicable
+            output = _maybe_calibrate_or_quantize(
+                module, output, "output", scheme.output_activations
+            )
+
+        return output
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
+def _maybe_calibrate_or_quantize(
+    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
+) -> torch.Tensor:
+    # only run quantization for the included stages
+    if module.quantization_status not in {
+        QuantizationStatus.CALIBRATION,
+        QuantizationStatus.FROZEN,
+    }:
+        return value
+
+    scale = getattr(module, f"{base_name}_scale")
+    zero_point = getattr(module, f"{base_name}_zero_point")
+
+    if module.quantization_status == QuantizationStatus.CALIBRATION:
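To make the fake-quantization math above concrete, here is a small standalone round trip with hand-picked scale and zero point (a sketch, not part of the PR):

# Standalone sketch of the quantize/dequantize round trip defined above.
import torch

from sparsetensors.quantization.lifecycle.forward import dequantize, quantize

x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
scale = torch.tensor(4.0 / 255)     # asymmetric 8-bit scale for a [-1, 3] range
zero_point = torch.tensor(63.75)    # (0 - min_val) / scale for min_val = -1
q_max = torch.tensor(255)

x_q = quantize(x, scale, zero_point, q_max)   # clamp(round(x / scale + zp), 0, 255)
x_dq = dequantize(x_q, scale, zero_point)     # (x_q - zp) * scale
# x_dq matches x to within about half a quantization step (scale / 2 ~= 0.008)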
+        # get observer and get new quant params from observation
+        observer = getattr(module, f"{base_name}_observer")
+        updated_scale, updated_zero_point = observer(value)
+
+        # update scale and zero point
+        scale.data = updated_scale
+        zero_point.data = updated_zero_point
+
+    return fake_quantize(value, scale, zero_point, args)
diff --git a/src/sparsetensors/quantization/lifecycle/frozen.py b/src/sparsetensors/quantization/lifecycle/frozen.py
new file mode 100644
index 00000000..6b92eee7
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/frozen.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "freeze_module_quantization",
+]
+
+
+def freeze_module_quantization(module: Module):
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme, nothing to do
+        return
+
+    # delete observers from module (materialize the iterator first so deleting
+    # submodules does not mutate the dict while named_modules() is iterating it)
+    for submodule_name, _ in list(module.named_modules()):
+        if "." not in submodule_name and submodule_name.endswith("_observer"):
+            # delete any observers that belong directly to this module
+            delattr(module, submodule_name)
+
+    module.quantization_status = QuantizationStatus.FROZEN
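A small sketch of what freezing does to a calibrated module (assuming a `layer` that was initialized and calibrated for weight quantization; names follow the `{base_name}_observer` convention used above):

# Sketch (not part of this PR): observers are dropped and status is pinned on freeze.
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.status import QuantizationStatus

# `layer` is assumed to have been initialized and calibrated for weight quantization
assert hasattr(layer, "weight_observer")

freeze_module_quantization(layer)

assert not hasattr(layer, "weight_observer")                   # observer submodule deleted
assert layer.quantization_status is QuantizationStatus.FROZEN
# weight_scale / weight_zero_point remain and are used as-is by the wrapped forward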
diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py
new file mode 100644
index 00000000..5661fdbd
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/initialize.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import torch
+from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
+from sparsetensors.quantization.lifecycle.status import QuantizationStatus
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from torch.nn import Module, Parameter
+
+
+__all__ = [
+    "initialize_module_for_quantization",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
+    if scheme.input_activations is not None:
+
+        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+    if scheme.weights is not None:
+        if hasattr(module, "weight"):
+            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+        else:
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for weight quantization but "
+                "has no attribute weight, skipping weight quantization "
+                f"for {type(module)}"
+            )
+    if scheme.output_activations is not None:
+        _initialize_scale_zero_point_observer(
+            module, "output", scheme.output_activations
+        )
+
+    module.quantization_scheme = scheme
+    module.quantization_status = QuantizationStatus.INITIALIZED
+
+    # wrap forward call of module to perform quantized actions based on calltime status
+    wrap_module_forward_quantized(module, scheme)
+
+
+def _initialize_scale_zero_point_observer(
+    module: Module, base_name: str, quantization_args: QuantizationArgs
+):
+    # initializes empty scale and zero point parameters for the module
+    init_scale = Parameter(torch.empty(0), requires_grad=False)
+    module.register_parameter(f"{base_name}_scale", init_scale)
+
+    init_zero_point = Parameter(torch.empty(0, dtype=int), requires_grad=False)
+    module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
diff --git a/src/sparsetensors/quantization/lifecycle/status.py b/src/sparsetensors/quantization/lifecycle/status.py
new file mode 100644
index 00000000..3b6a441d
--- /dev/null
+++ b/src/sparsetensors/quantization/lifecycle/status.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+__all__ = [
+    "QuantizationStatus",
+]
+
+
+class QuantizationStatus(Enum):
+    INITIALIZED = "INITIALIZED"
+    CALIBRATION = "CALIBRATION"
+    FROZEN = "FROZEN"
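Taken together, the lifecycle added in this PR is initialize -> calibrate -> freeze. A minimal end-to-end sketch (not part of the diff; `scheme` stands in for a QuantizationScheme configured elsewhere, since its definition is outside this change):

# End-to-end usage sketch (assumption: `scheme` is a QuantizationScheme built elsewhere;
# only its weights/input_activations/output_activations fields are read by this code).
import torch
from torch.nn import Linear

from sparsetensors.quantization.lifecycle import (
    freeze_module_quantization,
    initialize_module_for_quantization,
    set_module_for_calibration,
)

layer = Linear(16, 16)
initialize_module_for_quantization(layer, scheme)  # attaches *_scale, *_zero_point, *_observer
set_module_for_calibration(layer)                  # status INITIALIZED -> CALIBRATION

for _ in range(8):
    layer(torch.randn(4, 16))                      # observers refresh scale / zero point

freeze_module_quantization(layer)                  # status -> FROZEN, observers removed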
diff --git a/src/sparsetensors/quantization/observers/__init__.py b/src/sparsetensors/quantization/observers/__init__.py
new file mode 100644
index 00000000..d0362b8f
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .base import *
+from .memoryless import *
+from .min_max import *
diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py
new file mode 100644
index 00000000..52a464b9
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/base.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.registry.registry import RegistryMixin
+from torch import FloatTensor, IntTensor, Tensor
+from torch.nn import Module
+
+
+__all__ = ["Observer"]
+
+
+class Observer(Module, RegistryMixin):
+    """
+    Base Observer class to be subclassed for specific implementations.
+    Subclasses should override `calculate_qparams` to return a scale, zero_point
+    pair
+    """
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        super().__init__()
+        self.quantization_args: QuantizationArgs = quantization_args
+        self._scale = None
+        self._zero_point = None
+
+    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        maps directly to get_qparams
+        :param observed: optional observed tensor to calculate quantization parameters
+            from
+        :return: tuple of scale and zero point based on last observed value
+        """
+        return self.get_qparams(observed=observed)
+
+    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
+
+    def get_qparams(
+        self, observed: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
+        """
+        Convenience function to wrap overwritten calculate_qparams;
+        adds support for making the observed tensor optional and tracks the latest
+        calculated scale and zero point
+        :param observed: optional observed tensor to calculate quantization parameters
+            from
+        :return: tuple of scale and zero point based on last observed value
+        """
+        if observed is not None:
+            # re-calculate scale and zero point, update the stored values
+            self._scale, self._zero_point = self.calculate_qparams(observed)
+        return self._scale, self._zero_point
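Because Observer mixes in RegistryMixin, a new observer only needs to subclass it, implement calculate_qparams, and register a name. A minimal illustrative sketch (the "absmax" name and the simplistic math are hypothetical, not part of this PR):

# Illustrative sketch (not part of this PR): a custom observer registered like the
# built-in ones below; the math is deliberately simplistic.
from typing import Tuple

import torch
from torch import FloatTensor, IntTensor, Tensor

from sparsetensors.quantization.observers.base import Observer


@Observer.register("absmax")  # hypothetical registry name
class AbsMaxObserver(Observer):
    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        # symmetric 8-bit scale from the absolute max, zero point fixed at 0
        scale = 2 * observed.abs().max() / 255
        zero_point = torch.tensor(0).to(torch.int8)
        return scale, zero_point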
diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py
new file mode 100644
index 00000000..b69c841d
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/memoryless.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+from sparsetensors.quantization.observers.base import Observer
+from torch import FloatTensor, IntTensor, Tensor
+
+
+__all__ = ["MemorylessObserver"]
+
+
+@Observer.register("memoryless")
+class MemorylessObserver(Observer):
+    """
+    Implements a dynamic quantization observer that sets the scale and
+    zero point based on the latest observed value
+    """
+
+    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        # TODO: Add support for full range of quantization Args, only supports 8bit
+        #       per tensor
+        bit_range = 255
+        min_val = observed.min()
+        max_val = observed.max()
+
+        # ensure zero is in the range
+        min_val = torch.min(min_val, torch.zeros_like(min_val))
+        max_val = torch.max(max_val, torch.zeros_like(max_val))
+
+        if self.quantization_args.symmetric:
+            symmetric_range = 2 * max(min_val.abs(), max_val.abs())
+            scale = symmetric_range / bit_range
+            zero_point = torch.tensor(0).to(torch.int8)
+        else:
+            # non-symmetric
+            observed_range = max_val - min_val
+            scale = observed_range / bit_range
+
+            # scales from a 0 range should be set to 1
+            scale[observed_range == 0] = 1
+
+            zero_point = (0 - min_val) / scale
+
+        return scale, zero_point
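Tracing the asymmetric branch above with a concrete tensor (a worked sketch, not part of the PR; it assumes the remaining QuantizationArgs fields all have defaults):

# Worked sketch (not part of this PR): the asymmetric path of MemorylessObserver.
import torch

from sparsetensors.quantization.observers.memoryless import MemorylessObserver
from sparsetensors.quantization.quant_args import QuantizationArgs

# assumption: fields not shown in this diff (e.g. num_bits) default sensibly
observer = MemorylessObserver(QuantizationArgs(symmetric=False))
scale, zero_point = observer(torch.tensor([-1.0, 0.5, 3.0]))
# min_val = -1, max_val = 3 (zero is already inside the range)
# scale      = (3 - (-1)) / 255 = 4 / 255 ~= 0.0157
# zero_point = (0 - (-1)) / scale = 63.75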
diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py
new file mode 100644
index 00000000..40cde72c
--- /dev/null
+++ b/src/sparsetensors/quantization/observers/min_max.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+from sparsetensors.quantization.observers.base import Observer
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from torch import FloatTensor, IntTensor, Tensor
+
+
+__all__ = ["MinMaxObserver"]
+
+
+@Observer.register("minmax")
+class MinMaxObserver(Observer):
+    """
+    Implements a quantization observer that calculates scale and zero point based
+    on a running average of the observed min and max values
+    """
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        super().__init__(quantization_args=quantization_args)
+
+        self.min_val = float("inf")
+        self.max_val = -float("inf")
+        self.counter = 0
+
+    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        # TODO: Add support for full range of quantization Args, only supports 8bit
+        #       per tensor
+        bit_range = 255
+        min_val = torch.tensor([observed.min()])
+        max_val = torch.tensor([observed.max()])
+
+        # running average
+        if self.counter > 0:
+            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
+            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+        else:
+            self.min_val = min_val
+            self.max_val = max_val
+
+        # ensure that the zeros are in the range
+        self.min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
+        self.max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
+
+        self.counter += 1
+
+        if self.quantization_args.symmetric:
+            symmetric_range = 2 * max(self.min_val.abs(), self.max_val.abs())
+            scale = symmetric_range / bit_range
+            zero_point = torch.tensor(0).to(torch.int8)
+        else:
+            # non-symmetric
+            observed_range = self.max_val - self.min_val
+            scale = observed_range / bit_range
+
+            # scales from a 0 range should be set to 1
+            scale[observed_range == 0] = 1
+
+            zero_point = (0 - self.min_val) / scale
+
+        return scale, zero_point
diff --git a/src/sparsetensors/quantization/quant_args.py b/src/sparsetensors/quantization/quant_args.py
index 89a2e3df..d90fe9bc 100644
--- a/src/sparsetensors/quantization/quant_args.py
+++ b/src/sparsetensors/quantization/quant_args.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Optional
+from typing import Any, Dict, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -61,3 +61,25 @@ class QuantizationArgs(BaseModel):
     strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
     group_size: Optional[int] = None
     block_structure: Optional[str] = None
+    observer: str = Field(
+        default="minmax",
+        description=(
+            "The class to use to compute the quantization parameters - "
+            "scale and zero-point"
+        ),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=(
+            "optional dict of kwargs to be passed directly to the observer "
+            "constructor, excluding quantization range or symmetry"
+        ),
+    )
+
+    def get_observer(self):
+        """
+        :return: Observer built from the registry based on these QuantizationArgs
+        """
+        from sparsetensors.quantization.observers.base import Observer
+
+        return Observer.load_from_registry(self.observer, quantization_args=self)
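With the new fields, observer selection is driven entirely by QuantizationArgs through the registry. A short sketch (not part of the PR; it assumes the QuantizationArgs fields not shown in this diff have defaults):

# Sketch (not part of this PR): the observer field resolves through the registry.
import torch

from sparsetensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs()                       # observer defaults to "minmax"
observer = args.get_observer()                  # -> MinMaxObserver(quantization_args=args)
scale, zero_point = observer(torch.randn(32))   # running min/max qparams

dynamic_args = QuantizationArgs(observer="memoryless")
dynamic_observer = dynamic_args.get_observer()  # -> MemorylessObserver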